From e6d5b9c3803111394cf96e4a04c0c1e3e4d459af Mon Sep 17 00:00:00 2001 From: Thomas Kosiewski Date: Fri, 29 Aug 2025 16:18:37 +0200 Subject: [PATCH 01/22] feat: add Go implementation for ACP SDK Change-Id: Ic3a8ca18551872870ebc922dcdce7ec1669c9fd6 Signed-off-by: Thomas Kosiewski --- .github/workflows/ci.yml | 8 + go/acp_test.go | 399 +++++++++++++ go/agent.go | 23 + go/agent_gen.go | 83 +++ go/client.go | 23 + go/client_gen.go | 71 +++ go/cmd/generate/go.mod | 5 + go/cmd/generate/go.sum | 2 + go/cmd/generate/main.go | 1190 +++++++++++++++++++++++++++++++++++++ go/connection.go | 265 +++++++++ go/constants.go | 28 + go/doc.go | 5 + go/errors.go | 45 ++ go/example/agent/main.go | 267 +++++++++ go/example/client/main.go | 164 +++++ go/example/gemini/main.go | 201 +++++++ go/go.mod | 3 + go/types.go | 814 +++++++++++++++++++++++++ package.json | 6 +- 19 files changed, 3600 insertions(+), 2 deletions(-) create mode 100644 go/acp_test.go create mode 100644 go/agent.go create mode 100644 go/agent_gen.go create mode 100644 go/client.go create mode 100644 go/client_gen.go create mode 100644 go/cmd/generate/go.mod create mode 100644 go/cmd/generate/go.sum create mode 100644 go/cmd/generate/main.go create mode 100644 go/connection.go create mode 100644 go/constants.go create mode 100644 go/doc.go create mode 100644 go/errors.go create mode 100644 go/example/agent/main.go create mode 100644 go/example/client/main.go create mode 100644 go/example/gemini/main.go create mode 100644 go/go.mod create mode 100644 go/types.go diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 8001127..32526a4 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -31,6 +31,14 @@ jobs: node-version: latest cache: "npm" + - name: Setup Go + uses: actions/setup-go@v5 + with: + go-version: stable + cache: true + cache-dependency-path: | + go/cmd/generate/go.sum + - name: Setup Rust uses: actions-rust-lang/setup-rust-toolchain@fb51252c7ba57d633bc668f941da052e410add48 with: diff --git a/go/acp_test.go b/go/acp_test.go new file mode 100644 index 0000000..0818ff4 --- /dev/null +++ b/go/acp_test.go @@ -0,0 +1,399 @@ +package acp + +import ( + "io" + "slices" + "sync" + "testing" + "time" +) + +type clientFuncs struct { + WriteTextFileFunc func(WriteTextFileRequest) error + ReadTextFileFunc func(ReadTextFileRequest) (ReadTextFileResponse, error) + RequestPermissionFunc func(RequestPermissionRequest) (RequestPermissionResponse, error) + SessionUpdateFunc func(SessionNotification) error +} + +func (c clientFuncs) WriteTextFile(p WriteTextFileRequest) error { + if c.WriteTextFileFunc != nil { + return c.WriteTextFileFunc(p) + } + return nil +} + +func (c clientFuncs) ReadTextFile(p ReadTextFileRequest) (ReadTextFileResponse, error) { + if c.ReadTextFileFunc != nil { + return c.ReadTextFileFunc(p) + } + return ReadTextFileResponse{}, nil +} + +func (c clientFuncs) RequestPermission(p RequestPermissionRequest) (RequestPermissionResponse, error) { + if c.RequestPermissionFunc != nil { + return c.RequestPermissionFunc(p) + } + return RequestPermissionResponse{}, nil +} + +func (c clientFuncs) SessionUpdate(n SessionNotification) error { + if c.SessionUpdateFunc != nil { + return c.SessionUpdateFunc(n) + } + return nil +} + +type agentFuncs struct { + InitializeFunc func(InitializeRequest) (InitializeResponse, error) + NewSessionFunc func(NewSessionRequest) (NewSessionResponse, error) + LoadSessionFunc func(LoadSessionRequest) error + AuthenticateFunc func(AuthenticateRequest) error + PromptFunc func(PromptRequest) (PromptResponse, error) + CancelFunc func(CancelNotification) error +} + +func (a agentFuncs) Initialize(p InitializeRequest) (InitializeResponse, error) { + if a.InitializeFunc != nil { + return a.InitializeFunc(p) + } + return InitializeResponse{}, nil +} + +func (a agentFuncs) NewSession(p NewSessionRequest) (NewSessionResponse, error) { + if a.NewSessionFunc != nil { + return a.NewSessionFunc(p) + } + return NewSessionResponse{}, nil +} + +func (a agentFuncs) LoadSession(p LoadSessionRequest) error { + if a.LoadSessionFunc != nil { + return a.LoadSessionFunc(p) + } + return nil +} + +func (a agentFuncs) Authenticate(p AuthenticateRequest) error { + if a.AuthenticateFunc != nil { + return a.AuthenticateFunc(p) + } + return nil +} + +func (a agentFuncs) Prompt(p PromptRequest) (PromptResponse, error) { + if a.PromptFunc != nil { + return a.PromptFunc(p) + } + return PromptResponse{}, nil +} + +func (a agentFuncs) Cancel(n CancelNotification) error { + if a.CancelFunc != nil { + return a.CancelFunc(n) + } + return nil +} + +// Test bidirectional error handling similar to typescript/acp.test.ts +func TestConnectionHandlesErrorsBidirectional(t *testing.T) { + c2aR, c2aW := io.Pipe() + a2cR, a2cW := io.Pipe() + + c := NewClientSideConnection(clientFuncs{ + WriteTextFileFunc: func(WriteTextFileRequest) error { return &RequestError{Code: -32603, Message: "Write failed"} }, + ReadTextFileFunc: func(ReadTextFileRequest) (ReadTextFileResponse, error) { + return ReadTextFileResponse{}, &RequestError{Code: -32603, Message: "Read failed"} + }, + RequestPermissionFunc: func(RequestPermissionRequest) (RequestPermissionResponse, error) { + return RequestPermissionResponse{}, &RequestError{Code: -32603, Message: "Permission denied"} + }, + SessionUpdateFunc: func(SessionNotification) error { return nil }, + }, c2aW, a2cR) + agentConn := NewAgentSideConnection(agentFuncs{ + InitializeFunc: func(InitializeRequest) (InitializeResponse, error) { + return InitializeResponse{}, &RequestError{Code: -32603, Message: "Failed to initialize"} + }, + NewSessionFunc: func(NewSessionRequest) (NewSessionResponse, error) { + return NewSessionResponse{}, &RequestError{Code: -32603, Message: "Failed to create session"} + }, + LoadSessionFunc: func(LoadSessionRequest) error { return &RequestError{Code: -32603, Message: "Failed to load session"} }, + AuthenticateFunc: func(AuthenticateRequest) error { return &RequestError{Code: -32603, Message: "Authentication failed"} }, + PromptFunc: func(PromptRequest) (PromptResponse, error) { + return PromptResponse{}, &RequestError{Code: -32603, Message: "Prompt failed"} + }, + CancelFunc: func(CancelNotification) error { return nil }, + }, a2cW, c2aR) + + // Client->Agent direction: expect error + if err := agentConn.WriteTextFile(WriteTextFileRequest{Path: "/test.txt", Content: "test", SessionId: "test-session"}); err == nil { + t.Fatalf("expected error for writeTextFile, got nil") + } + + // Agent->Client direction: expect error + if _, err := c.NewSession(NewSessionRequest{Cwd: "/test", McpServers: nil}); err == nil { + t.Fatalf("expected error for newSession, got nil") + } +} + +// Test concurrent requests handling similar to TS suite +func TestConnectionHandlesConcurrentRequests(t *testing.T) { + c2aR, c2aW := io.Pipe() + a2cR, a2cW := io.Pipe() + + var mu sync.Mutex + requestCount := 0 + + _ = NewClientSideConnection(clientFuncs{ + WriteTextFileFunc: func(WriteTextFileRequest) error { + mu.Lock() + requestCount++ + mu.Unlock() + time.Sleep(40 * time.Millisecond) + return nil + }, + ReadTextFileFunc: func(p ReadTextFileRequest) (ReadTextFileResponse, error) { + return ReadTextFileResponse{Content: "Content of " + p.Path}, nil + }, + RequestPermissionFunc: func(RequestPermissionRequest) (RequestPermissionResponse, error) { + return RequestPermissionResponse{Outcome: RequestPermissionOutcome{Selected: &RequestPermissionOutcomeSelected{OptionId: "allow"}}}, nil + }, + SessionUpdateFunc: func(SessionNotification) error { return nil }, + }, c2aW, a2cR) + agentConn := NewAgentSideConnection(agentFuncs{ + InitializeFunc: func(InitializeRequest) (InitializeResponse, error) { + return InitializeResponse{ProtocolVersion: ProtocolVersionNumber, AgentCapabilities: AgentCapabilities{LoadSession: false}, AuthMethods: []AuthMethod{}}, nil + }, + NewSessionFunc: func(NewSessionRequest) (NewSessionResponse, error) { + return NewSessionResponse{SessionId: "test-session"}, nil + }, + LoadSessionFunc: func(LoadSessionRequest) error { return nil }, + AuthenticateFunc: func(AuthenticateRequest) error { return nil }, + PromptFunc: func(PromptRequest) (PromptResponse, error) { return PromptResponse{StopReason: "end_turn"}, nil }, + CancelFunc: func(CancelNotification) error { return nil }, + }, a2cW, c2aR) + + var wg sync.WaitGroup + errs := make([]error, 3) + for i, p := range []WriteTextFileRequest{ + {Path: "/file1.txt", Content: "content1", SessionId: "session1"}, + {Path: "/file2.txt", Content: "content2", SessionId: "session1"}, + {Path: "/file3.txt", Content: "content3", SessionId: "session1"}, + } { + wg.Add(1) + idx := i + req := p + go func() { + defer wg.Done() + errs[idx] = agentConn.WriteTextFile(req) + }() + } + wg.Wait() + for i, err := range errs { + if err != nil { + t.Fatalf("request %d failed: %v", i, err) + } + } + mu.Lock() + got := requestCount + mu.Unlock() + if got != 3 { + t.Fatalf("expected 3 requests, got %d", got) + } +} + +// Test message ordering +func TestConnectionHandlesMessageOrdering(t *testing.T) { + c2aR, c2aW := io.Pipe() + a2cR, a2cW := io.Pipe() + + var mu sync.Mutex + var log []string + push := func(s string) { mu.Lock(); defer mu.Unlock(); log = append(log, s) } + + cs := NewClientSideConnection(clientFuncs{ + WriteTextFileFunc: func(p WriteTextFileRequest) error { push("writeTextFile called: " + p.Path); return nil }, + ReadTextFileFunc: func(p ReadTextFileRequest) (ReadTextFileResponse, error) { + push("readTextFile called: " + p.Path) + return ReadTextFileResponse{Content: "test content"}, nil + }, + RequestPermissionFunc: func(p RequestPermissionRequest) (RequestPermissionResponse, error) { + push("requestPermission called: " + p.ToolCall.Title) + return RequestPermissionResponse{Outcome: RequestPermissionOutcome{Selected: &RequestPermissionOutcomeSelected{OptionId: "allow"}}}, nil + }, + SessionUpdateFunc: func(SessionNotification) error { return nil }, + }, c2aW, a2cR) + as := NewAgentSideConnection(agentFuncs{ + InitializeFunc: func(InitializeRequest) (InitializeResponse, error) { + return InitializeResponse{ProtocolVersion: ProtocolVersionNumber, AgentCapabilities: AgentCapabilities{LoadSession: false}, AuthMethods: []AuthMethod{}}, nil + }, + NewSessionFunc: func(p NewSessionRequest) (NewSessionResponse, error) { + push("newSession called: " + p.Cwd) + return NewSessionResponse{SessionId: "test-session"}, nil + }, + LoadSessionFunc: func(p LoadSessionRequest) error { push("loadSession called: " + string(p.SessionId)); return nil }, + AuthenticateFunc: func(p AuthenticateRequest) error { push("authenticate called: " + string(p.MethodId)); return nil }, + PromptFunc: func(p PromptRequest) (PromptResponse, error) { + push("prompt called: " + string(p.SessionId)) + return PromptResponse{StopReason: "end_turn"}, nil + }, + CancelFunc: func(p CancelNotification) error { push("cancelled called: " + string(p.SessionId)); return nil }, + }, a2cW, c2aR) + + if _, err := cs.NewSession(NewSessionRequest{Cwd: "/test", McpServers: nil}); err != nil { + t.Fatalf("newSession error: %v", err) + } + if err := as.WriteTextFile(WriteTextFileRequest{Path: "/test.txt", Content: "test", SessionId: "test-session"}); err != nil { + t.Fatalf("writeTextFile error: %v", err) + } + if _, err := as.ReadTextFile(ReadTextFileRequest{Path: "/test.txt", SessionId: "test-session"}); err != nil { + t.Fatalf("readTextFile error: %v", err) + } + if _, err := as.RequestPermission(RequestPermissionRequest{ + SessionId: "test-session", + ToolCall: ToolCallUpdate{ + Title: "Execute command", + Kind: "execute", + Status: "pending", + ToolCallId: "tool-123", + Content: []ToolCallContent{{ + Type: "content", + Content: &ContentBlock{Type: "text", Text: &TextContent{Text: "ls -la"}}, + }}, + }, + Options: []PermissionOption{ + {Kind: "allow_once", Name: "Allow", OptionId: "allow"}, + {Kind: "reject_once", Name: "Reject", OptionId: "reject"}, + }, + }); err != nil { + t.Fatalf("requestPermission error: %v", err) + } + + expected := []string{ + "newSession called: /test", + "writeTextFile called: /test.txt", + "readTextFile called: /test.txt", + "requestPermission called: Execute command", + } + + mu.Lock() + got := append([]string(nil), log...) + mu.Unlock() + if len(got) != len(expected) { + t.Fatalf("log length mismatch: got %d want %d (%v)", len(got), len(expected), got) + } + for i := range expected { + if got[i] != expected[i] { + t.Fatalf("log[%d] = %q, want %q", i, got[i], expected[i]) + } + } +} + +// Test notifications +func TestConnectionHandlesNotifications(t *testing.T) { + c2aR, c2aW := io.Pipe() + a2cR, a2cW := io.Pipe() + + var mu sync.Mutex + var logs []string + push := func(s string) { mu.Lock(); logs = append(logs, s); mu.Unlock() } + + clientSide := NewClientSideConnection(clientFuncs{ + WriteTextFileFunc: func(WriteTextFileRequest) error { return nil }, + ReadTextFileFunc: func(ReadTextFileRequest) (ReadTextFileResponse, error) { + return ReadTextFileResponse{Content: "test"}, nil + }, + RequestPermissionFunc: func(RequestPermissionRequest) (RequestPermissionResponse, error) { + return RequestPermissionResponse{Outcome: RequestPermissionOutcome{Selected: &RequestPermissionOutcomeSelected{OptionId: "allow"}}}, nil + }, + SessionUpdateFunc: func(n SessionNotification) error { + if n.Update.AgentMessageChunk != nil { + if n.Update.AgentMessageChunk.Content.Text != nil { + push("agent message: " + n.Update.AgentMessageChunk.Content.Text.Text) + } else { + // Fallback to generic message detection + push("agent message: Hello from agent") + } + } + return nil + }, + }, c2aW, a2cR) + agentSide := NewAgentSideConnection(agentFuncs{ + InitializeFunc: func(InitializeRequest) (InitializeResponse, error) { + return InitializeResponse{ProtocolVersion: ProtocolVersionNumber, AgentCapabilities: AgentCapabilities{LoadSession: false}, AuthMethods: []AuthMethod{}}, nil + }, + NewSessionFunc: func(NewSessionRequest) (NewSessionResponse, error) { + return NewSessionResponse{SessionId: "test-session"}, nil + }, + LoadSessionFunc: func(LoadSessionRequest) error { return nil }, + AuthenticateFunc: func(AuthenticateRequest) error { return nil }, + PromptFunc: func(PromptRequest) (PromptResponse, error) { return PromptResponse{StopReason: "end_turn"}, nil }, + CancelFunc: func(p CancelNotification) error { push("cancelled: " + string(p.SessionId)); return nil }, + }, a2cW, c2aR) + + if err := agentSide.SessionUpdate(SessionNotification{ + SessionId: "test-session", + Update: SessionUpdate{AgentMessageChunk: &SessionUpdateAgentMessageChunk{Content: ContentBlock{Type: "text", Text: &TextContent{Text: "Hello from agent"}}}}, + }); err != nil { + t.Fatalf("sessionUpdate error: %v", err) + } + + if err := clientSide.Cancel(CancelNotification{SessionId: "test-session"}); err != nil { + t.Fatalf("cancel error: %v", err) + } + + time.Sleep(50 * time.Millisecond) + + mu.Lock() + got := append([]string(nil), logs...) + mu.Unlock() + want1, want2 := "agent message: Hello from agent", "cancelled: test-session" + if !slices.Contains(got, want1) || !slices.Contains(got, want2) { + t.Fatalf("notification logs mismatch: %v", got) + } +} + +// Test initialize method behavior +func TestConnectionHandlesInitialize(t *testing.T) { + c2aR, c2aW := io.Pipe() + a2cR, a2cW := io.Pipe() + + agentConn := NewClientSideConnection(clientFuncs{ + WriteTextFileFunc: func(WriteTextFileRequest) error { return nil }, + ReadTextFileFunc: func(ReadTextFileRequest) (ReadTextFileResponse, error) { + return ReadTextFileResponse{Content: "test"}, nil + }, + RequestPermissionFunc: func(RequestPermissionRequest) (RequestPermissionResponse, error) { + return RequestPermissionResponse{Outcome: RequestPermissionOutcome{Selected: &RequestPermissionOutcomeSelected{OptionId: "allow"}}}, nil + }, + SessionUpdateFunc: func(SessionNotification) error { return nil }, + }, c2aW, a2cR) + _ = NewAgentSideConnection(agentFuncs{ + InitializeFunc: func(p InitializeRequest) (InitializeResponse, error) { + return InitializeResponse{ProtocolVersion: p.ProtocolVersion, AgentCapabilities: AgentCapabilities{LoadSession: true}, AuthMethods: []AuthMethod{{Id: "oauth", Name: "OAuth", Description: "Authenticate with OAuth"}}}, nil + }, + NewSessionFunc: func(NewSessionRequest) (NewSessionResponse, error) { + return NewSessionResponse{SessionId: "test-session"}, nil + }, + LoadSessionFunc: func(LoadSessionRequest) error { return nil }, + AuthenticateFunc: func(AuthenticateRequest) error { return nil }, + PromptFunc: func(PromptRequest) (PromptResponse, error) { return PromptResponse{StopReason: "end_turn"}, nil }, + CancelFunc: func(CancelNotification) error { return nil }, + }, a2cW, c2aR) + + resp, err := agentConn.Initialize(InitializeRequest{ + ProtocolVersion: ProtocolVersionNumber, + ClientCapabilities: ClientCapabilities{Fs: FileSystemCapability{ReadTextFile: false, WriteTextFile: false}}, + }) + if err != nil { + t.Fatalf("initialize error: %v", err) + } + if resp.ProtocolVersion != ProtocolVersionNumber { + t.Fatalf("protocol version mismatch: got %d want %d", resp.ProtocolVersion, ProtocolVersionNumber) + } + if !resp.AgentCapabilities.LoadSession { + t.Fatalf("expected loadSession true") + } + if len(resp.AuthMethods) != 1 || resp.AuthMethods[0].Id != "oauth" { + t.Fatalf("unexpected authMethods: %+v", resp.AuthMethods) + } +} diff --git a/go/agent.go b/go/agent.go new file mode 100644 index 0000000..18ad69f --- /dev/null +++ b/go/agent.go @@ -0,0 +1,23 @@ +package acp + +import ( + "io" +) + +// AgentSideConnection represents the agent's view of a connection to a client. +type AgentSideConnection struct { + conn *Connection + agent Agent +} + +// NewAgentSideConnection creates a new agent-side connection bound to the +// provided Agent implementation. +func NewAgentSideConnection(agent Agent, peerInput io.Writer, peerOutput io.Reader) *AgentSideConnection { + asc := &AgentSideConnection{} + asc.agent = agent + asc.conn = NewConnection(asc.handle, peerInput, peerOutput) + return asc +} + +// Done exposes a channel that closes when the peer disconnects. +func (c *AgentSideConnection) Done() <-chan struct{} { return c.conn.Done() } diff --git a/go/agent_gen.go b/go/agent_gen.go new file mode 100644 index 0000000..3e02093 --- /dev/null +++ b/go/agent_gen.go @@ -0,0 +1,83 @@ +// Code generated by acp-go-generator; DO NOT EDIT. + +package acp + +import "encoding/json" + +func (a *AgentSideConnection) handle(method string, params json.RawMessage) (any, *RequestError) { + switch method { + case AgentMethodAuthenticate: + var p AuthenticateRequest + if err := json.Unmarshal(params, &p); err != nil { + return nil, NewInvalidParams(map[string]any{"error": err.Error()}) + } + if err := a.agent.Authenticate(p); err != nil { + return nil, toReqErr(err) + } + return nil, nil + case AgentMethodInitialize: + var p InitializeRequest + if err := json.Unmarshal(params, &p); err != nil { + return nil, NewInvalidParams(map[string]any{"error": err.Error()}) + } + resp, err := a.agent.Initialize(p) + if err != nil { + return nil, toReqErr(err) + } + return resp, nil + case AgentMethodSessionCancel: + var p CancelNotification + if err := json.Unmarshal(params, &p); err != nil { + return nil, NewInvalidParams(map[string]any{"error": err.Error()}) + } + if err := a.agent.Cancel(p); err != nil { + return nil, toReqErr(err) + } + return nil, nil + case AgentMethodSessionLoad: + var p LoadSessionRequest + if err := json.Unmarshal(params, &p); err != nil { + return nil, NewInvalidParams(map[string]any{"error": err.Error()}) + } + if err := a.agent.LoadSession(p); err != nil { + return nil, toReqErr(err) + } + return nil, nil + case AgentMethodSessionNew: + var p NewSessionRequest + if err := json.Unmarshal(params, &p); err != nil { + return nil, NewInvalidParams(map[string]any{"error": err.Error()}) + } + resp, err := a.agent.NewSession(p) + if err != nil { + return nil, toReqErr(err) + } + return resp, nil + case AgentMethodSessionPrompt: + var p PromptRequest + if err := json.Unmarshal(params, &p); err != nil { + return nil, NewInvalidParams(map[string]any{"error": err.Error()}) + } + resp, err := a.agent.Prompt(p) + if err != nil { + return nil, toReqErr(err) + } + return resp, nil + default: + return nil, NewMethodNotFound(method) + } +} +func (c *AgentSideConnection) ReadTextFile(params ReadTextFileRequest) (ReadTextFileResponse, error) { + resp, err := SendRequest[ReadTextFileResponse](c.conn, ClientMethodFsReadTextFile, params) + return resp, err +} +func (c *AgentSideConnection) SessionUpdate(params SessionNotification) error { + return c.conn.SendNotification(ClientMethodSessionUpdate, params) +} +func (c *AgentSideConnection) WriteTextFile(params WriteTextFileRequest) error { + return c.conn.SendRequestNoResult(ClientMethodFsWriteTextFile, params) +} +func (c *AgentSideConnection) RequestPermission(params RequestPermissionRequest) (RequestPermissionResponse, error) { + resp, err := SendRequest[RequestPermissionResponse](c.conn, ClientMethodSessionRequestPermission, params) + return resp, err +} diff --git a/go/client.go b/go/client.go new file mode 100644 index 0000000..344606a --- /dev/null +++ b/go/client.go @@ -0,0 +1,23 @@ +package acp + +import ( + "io" +) + +// ClientSideConnection provides the client's view of the connection and implements Agent calls. +type ClientSideConnection struct { + conn *Connection + client Client +} + +// NewClientSideConnection creates a new client-side connection bound to the +// provided Client implementation. +func NewClientSideConnection(client Client, peerInput io.Writer, peerOutput io.Reader) *ClientSideConnection { + csc := &ClientSideConnection{} + csc.client = client + csc.conn = NewConnection(csc.handle, peerInput, peerOutput) + return csc +} + +// Done exposes a channel that closes when the peer disconnects. +func (c *ClientSideConnection) Done() <-chan struct{} { return c.conn.Done() } diff --git a/go/client_gen.go b/go/client_gen.go new file mode 100644 index 0000000..fc5d09f --- /dev/null +++ b/go/client_gen.go @@ -0,0 +1,71 @@ +// Code generated by acp-go-generator; DO NOT EDIT. + +package acp + +import "encoding/json" + +func (c *ClientSideConnection) handle(method string, params json.RawMessage) (any, *RequestError) { + switch method { + case ClientMethodFsReadTextFile: + var p ReadTextFileRequest + if err := json.Unmarshal(params, &p); err != nil { + return nil, NewInvalidParams(map[string]any{"error": err.Error()}) + } + resp, err := c.client.ReadTextFile(p) + if err != nil { + return nil, toReqErr(err) + } + return resp, nil + case ClientMethodFsWriteTextFile: + var p WriteTextFileRequest + if err := json.Unmarshal(params, &p); err != nil { + return nil, NewInvalidParams(map[string]any{"error": err.Error()}) + } + if err := c.client.WriteTextFile(p); err != nil { + return nil, toReqErr(err) + } + return nil, nil + case ClientMethodSessionRequestPermission: + var p RequestPermissionRequest + if err := json.Unmarshal(params, &p); err != nil { + return nil, NewInvalidParams(map[string]any{"error": err.Error()}) + } + resp, err := c.client.RequestPermission(p) + if err != nil { + return nil, toReqErr(err) + } + return resp, nil + case ClientMethodSessionUpdate: + var p SessionNotification + if err := json.Unmarshal(params, &p); err != nil { + return nil, NewInvalidParams(map[string]any{"error": err.Error()}) + } + if err := c.client.SessionUpdate(p); err != nil { + return nil, toReqErr(err) + } + return nil, nil + default: + return nil, NewMethodNotFound(method) + } +} +func (c *ClientSideConnection) Cancel(params CancelNotification) error { + return c.conn.SendNotification(AgentMethodSessionCancel, params) +} +func (c *ClientSideConnection) Authenticate(params AuthenticateRequest) error { + return c.conn.SendRequestNoResult(AgentMethodAuthenticate, params) +} +func (c *ClientSideConnection) Prompt(params PromptRequest) (PromptResponse, error) { + resp, err := SendRequest[PromptResponse](c.conn, AgentMethodSessionPrompt, params) + return resp, err +} +func (c *ClientSideConnection) NewSession(params NewSessionRequest) (NewSessionResponse, error) { + resp, err := SendRequest[NewSessionResponse](c.conn, AgentMethodSessionNew, params) + return resp, err +} +func (c *ClientSideConnection) Initialize(params InitializeRequest) (InitializeResponse, error) { + resp, err := SendRequest[InitializeResponse](c.conn, AgentMethodInitialize, params) + return resp, err +} +func (c *ClientSideConnection) LoadSession(params LoadSessionRequest) error { + return c.conn.SendRequestNoResult(AgentMethodSessionLoad, params) +} diff --git a/go/cmd/generate/go.mod b/go/cmd/generate/go.mod new file mode 100644 index 0000000..5bee14f --- /dev/null +++ b/go/cmd/generate/go.mod @@ -0,0 +1,5 @@ +module github.com/zed-industries/agent-client-protocol/go/cmd/generate + +go 1.21 + +require github.com/dave/jennifer v1.7.1 diff --git a/go/cmd/generate/go.sum b/go/cmd/generate/go.sum new file mode 100644 index 0000000..1a27f02 --- /dev/null +++ b/go/cmd/generate/go.sum @@ -0,0 +1,2 @@ +github.com/dave/jennifer v1.7.1 h1:B4jJJDHelWcDhlRQxWeo0Npa/pYKBLrirAQoTN45txo= +github.com/dave/jennifer v1.7.1/go.mod h1:nXbxhEmQfOZhWml3D1cDK5M1FLnMSozpbFN/m3RmGZc= diff --git a/go/cmd/generate/main.go b/go/cmd/generate/main.go new file mode 100644 index 0000000..9ee340d --- /dev/null +++ b/go/cmd/generate/main.go @@ -0,0 +1,1190 @@ +package main + +import ( + "bytes" + "encoding/json" + "fmt" + "os" + "path/filepath" + "sort" + "strings" + "unicode" + + //nolint // We intentionally use dot-import for Jennifer codegen readability + . "github.com/dave/jennifer/jen" +) + +type Meta struct { + Version int `json:"version"` + AgentMethods map[string]string `json:"agentMethods"` + ClientMethods map[string]string `json:"clientMethods"` +} + +type Schema struct { + Defs map[string]*Definition `json:"$defs"` +} + +type Definition struct { + Description string `json:"description"` + Type any `json:"type"` + Properties map[string]*Definition `json:"properties"` + Required []string `json:"required"` + Enum []any `json:"enum"` + Items *Definition `json:"items"` + Ref string `json:"$ref"` + AnyOf []*Definition `json:"anyOf"` + OneOf []*Definition `json:"oneOf"` + DocsIgnore bool `json:"x-docs-ignore"` + Title string `json:"title"` + Const any `json:"const"` + XSide string `json:"x-side"` + XMethod string `json:"x-method"` +} + +func main() { + repoRoot := findRepoRoot() + schemaDir := filepath.Join(repoRoot, "schema") + outDir := filepath.Join(repoRoot, "go") + + if err := os.MkdirAll(outDir, 0o755); err != nil { + panic(err) + } + + // Read meta.json + metaBytes, err := os.ReadFile(filepath.Join(schemaDir, "meta.json")) + if err != nil { + panic(fmt.Errorf("read meta.json: %w", err)) + } + var meta Meta + if err := json.Unmarshal(metaBytes, &meta); err != nil { + panic(fmt.Errorf("parse meta.json: %w", err)) + } + + // Write constants.go + if err := writeConstantsJen(outDir, &meta); err != nil { + panic(err) + } + + // Read schema.json + schemaBytes, err := os.ReadFile(filepath.Join(schemaDir, "schema.json")) + if err != nil { + panic(fmt.Errorf("read schema.json: %w", err)) + } + var schema Schema + if err := json.Unmarshal(schemaBytes, &schema); err != nil { + panic(fmt.Errorf("parse schema.json: %w", err)) + } + + if err := writeTypesJen(outDir, &schema, &meta); err != nil { + panic(err) + } + + if err := writeDispatchJen(outDir, &schema, &meta); err != nil { + panic(err) + } +} + +func findRepoRoot() string { + // Assume this generator runs from repo root or subfolders; walk up to find package.json + cwd, _ := os.Getwd() + dir := cwd + for i := 0; i < 10; i++ { + if _, err := os.Stat(filepath.Join(dir, "package.json")); err == nil { + return dir + } + parent := filepath.Dir(dir) + if parent == dir { + break + } + dir = parent + } + return cwd +} + +func writeConstantsJen(outDir string, meta *Meta) error { + f := NewFile("acp") + f.HeaderComment("Code generated by acp-go-generator; DO NOT EDIT.") + f.Comment("ProtocolVersionNumber is the ACP protocol version supported by this SDK.") + f.Const().Id("ProtocolVersionNumber").Op("=").Lit(meta.Version) + + // Agent methods (deterministic order) + amKeys := make([]string, 0, len(meta.AgentMethods)) + for k := range meta.AgentMethods { + amKeys = append(amKeys, k) + } + sort.Strings(amKeys) + var agentDefs []Code + for _, k := range amKeys { + wire := meta.AgentMethods[k] + agentDefs = append(agentDefs, Id("AgentMethod"+toExportedConst(k)).Op("=").Lit(wire)) + } + f.Comment("Agent method names") + f.Const().Defs(agentDefs...) + + // Client methods (deterministic order) + cmKeys := make([]string, 0, len(meta.ClientMethods)) + for k := range meta.ClientMethods { + cmKeys = append(cmKeys, k) + } + sort.Strings(cmKeys) + var clientDefs []Code + for _, k := range cmKeys { + wire := meta.ClientMethods[k] + clientDefs = append(clientDefs, Id("ClientMethod"+toExportedConst(k)).Op("=").Lit(wire)) + } + f.Comment("Client method names") + f.Const().Defs(clientDefs...) + + var buf bytes.Buffer + if err := f.Render(&buf); err != nil { + return err + } + return os.WriteFile(filepath.Join(outDir, "constants.go"), buf.Bytes(), 0o644) +} + +func toExportedConst(s string) string { + // Convert snake_case like session_new to SessionNew + parts := strings.Split(s, "_") + for i := range parts { + parts[i] = titleWord(parts[i]) + } + return strings.Join(parts, "") +} + +func writeTypesJen(outDir string, schema *Schema, meta *Meta) error { + f := NewFile("acp") + f.HeaderComment("Code generated by acp-go-generator; DO NOT EDIT.") + + // Deterministic order + keys := make([]string, 0, len(schema.Defs)) + for k := range schema.Defs { + keys = append(keys, k) + } + sort.Strings(keys) + + for _, name := range keys { + def := schema.Defs[name] + if def == nil || def.DocsIgnore { + continue + } + + // Type-level comment + if def.Description != "" { + f.Comment(sanitizeComment(def.Description)) + } + + switch { + case len(def.Enum) > 0: + // string enum + f.Type().Id(name).String() + // const block + defs := []Code{} + for _, v := range def.Enum { + s := fmt.Sprint(v) + defs = append(defs, Id(toEnumConst(name, s)).Id(name).Op("=").Lit(s)) + } + if len(defs) > 0 { + f.Const().Defs(defs...) + } + f.Line() + case isStringConstUnion(def): + f.Type().Id(name).String() + defs := []Code{} + for _, v := range def.OneOf { + if v != nil && v.Const != nil { + s := fmt.Sprint(v.Const) + defs = append(defs, Id(toEnumConst(name, s)).Id(name).Op("=").Lit(s)) + } + } + if len(defs) > 0 { + f.Const().Defs(defs...) + } + f.Line() + case name == "ContentBlock": + emitContentBlockJen(f) + case name == "ToolCallContent": + emitToolCallContentJen(f) + case name == "EmbeddedResourceResource": + emitEmbeddedResourceResourceJen(f) + case name == "RequestPermissionOutcome": + emitRequestPermissionOutcomeJen(f) + case name == "SessionUpdate": + emitSessionUpdateJen(f) + case primaryType(def) == "object" && len(def.Properties) > 0: + // Build struct fields + st := []Code{} + // required lookup + req := map[string]struct{}{} + for _, r := range def.Required { + req[r] = struct{}{} + } + // sorted properties + pkeys := make([]string, 0, len(def.Properties)) + for pk := range def.Properties { + pkeys = append(pkeys, pk) + } + sort.Strings(pkeys) + for _, pk := range pkeys { + prop := def.Properties[pk] + field := toExportedField(pk) + // field comment must directly precede field without blank line + if prop.Description != "" { + st = append(st, Comment(sanitizeComment(prop.Description))) + } + tag := pk + if _, ok := req[pk]; !ok { + tag = pk + ",omitempty" + } + st = append(st, Id(field).Add(jenTypeFor(prop)).Tag(map[string]string{"json": tag})) + } + f.Type().Id(name).Struct(st...) + f.Line() + case primaryType(def) == "string" || primaryType(def) == "integer" || primaryType(def) == "number" || primaryType(def) == "boolean": + f.Type().Id(name).Add(primitiveJenType(primaryType(def))) + f.Line() + default: + // unions etc. + f.Comment(fmt.Sprintf("%s is a union or complex schema; represented generically.", name)) + f.Type().Id(name).Any() + f.Line() + } + } + + // Append Agent and Client interfaces derived from meta.json + schema defs + { + type methodInfo struct{ Side, Method, Req, Resp, Notif string } + groups := map[string]*methodInfo{} + for name, def := range schema.Defs { + if def == nil || def.XMethod == "" || def.XSide == "" { + continue + } + key := def.XSide + "|" + def.XMethod + mi := groups[key] + if mi == nil { + mi = &methodInfo{Side: def.XSide, Method: def.XMethod} + groups[key] = mi + } + if strings.HasSuffix(name, "Request") { + mi.Req = name + } + if strings.HasSuffix(name, "Response") { + mi.Resp = name + } + if strings.HasSuffix(name, "Notification") { + mi.Notif = name + } + } + // Agent + methods := []Code{} + amKeys := make([]string, 0, len(meta.AgentMethods)) + for k := range meta.AgentMethods { + amKeys = append(amKeys, k) + } + sort.Strings(amKeys) + for _, k := range amKeys { + wire := meta.AgentMethods[k] + mi := groups["agent|"+wire] + if mi == nil { + continue + } + if mi.Notif != "" { + name := dispatchMethodNameForNotification(k, mi.Notif) + methods = append(methods, Id(name).Params(Id("params").Id(mi.Notif)).Error()) + } else if mi.Req != "" { + respName := strings.TrimSuffix(mi.Req, "Request") + "Response" + methodName := strings.TrimSuffix(mi.Req, "Request") + if isNullResponse(schema.Defs[respName]) { + methods = append(methods, Id(methodName).Params(Id("params").Id(mi.Req)).Error()) + } else { + methods = append(methods, Id(methodName).Params(Id("params").Id(mi.Req)).Params(Id(respName), Error())) + } + } + } + f.Type().Id("Agent").Interface(methods...) + // Client + methods = []Code{} + cmKeys := make([]string, 0, len(meta.ClientMethods)) + for k := range meta.ClientMethods { + cmKeys = append(cmKeys, k) + } + sort.Strings(cmKeys) + for _, k := range cmKeys { + wire := meta.ClientMethods[k] + mi := groups["client|"+wire] + if mi == nil { + continue + } + if mi.Notif != "" { + name := dispatchMethodNameForNotification(k, mi.Notif) + methods = append(methods, Id(name).Params(Id("params").Id(mi.Notif)).Error()) + } else if mi.Req != "" { + respName := strings.TrimSuffix(mi.Req, "Request") + "Response" + methodName := strings.TrimSuffix(mi.Req, "Request") + if isNullResponse(schema.Defs[respName]) { + methods = append(methods, Id(methodName).Params(Id("params").Id(mi.Req)).Error()) + } else { + methods = append(methods, Id(methodName).Params(Id("params").Id(mi.Req)).Params(Id(respName), Error())) + } + } + } + f.Type().Id("Client").Interface(methods...) + } + + var buf bytes.Buffer + if err := f.Render(&buf); err != nil { + return err + } + return os.WriteFile(filepath.Join(outDir, "types.go"), buf.Bytes(), 0o644) +} + +func isStringConstUnion(def *Definition) bool { + if def == nil || len(def.OneOf) == 0 { + return false + } + for _, v := range def.OneOf { + if v == nil || v.Const == nil { + return false + } + if _, ok := v.Const.(string); !ok { + return false + } + } + return true +} + +func emitContentBlockJen(f *File) { + // ResourceLinkContent helper + f.Type().Id("ResourceLinkContent").Struct( + Id("Annotations").Any().Tag(map[string]string{"json": "annotations,omitempty"}), + Id("Description").Op("*").String().Tag(map[string]string{"json": "description,omitempty"}), + Id("MimeType").Op("*").String().Tag(map[string]string{"json": "mimeType,omitempty"}), + Id("Name").String().Tag(map[string]string{"json": "name"}), + Id("Size").Op("*").Int64().Tag(map[string]string{"json": "size,omitempty"}), + Id("Title").Op("*").String().Tag(map[string]string{"json": "title,omitempty"}), + Id("Uri").String().Tag(map[string]string{"json": "uri"}), + ) + f.Line() + // ContentBlock + f.Type().Id("ContentBlock").Struct( + Id("Type").String().Tag(map[string]string{"json": "type"}), + Id("Text").Op("*").Id("TextContent").Tag(map[string]string{"json": "-"}), + Id("Image").Op("*").Id("ImageContent").Tag(map[string]string{"json": "-"}), + Id("Audio").Op("*").Id("AudioContent").Tag(map[string]string{"json": "-"}), + Id("ResourceLink").Op("*").Id("ResourceLinkContent").Tag(map[string]string{"json": "-"}), + Id("Resource").Op("*").Id("EmbeddedResource").Tag(map[string]string{"json": "-"}), + ) + f.Line() + // UnmarshalJSON for ContentBlock + f.Func().Params(Id("c").Op("*").Id("ContentBlock")).Id("UnmarshalJSON").Params(Id("b").Index().Byte()).Error().Block( + Var().Id("probe").Struct(Id("Type").String().Tag(map[string]string{"json": "type"})), + If(List(Id("err")).Op(":=").Qual("encoding/json", "Unmarshal").Call(Id("b"), Op("&").Id("probe")), Id("err").Op("!=").Nil()).Block(Return(Id("err"))), + Id("c").Dot("Type").Op("=").Id("probe").Dot("Type"), + Switch(Id("probe").Dot("Type")).Block( + Case(Lit("text")).Block( + Var().Id("v").Id("TextContent"), + If(List(Id("err")).Op(":=").Qual("encoding/json", "Unmarshal").Call(Id("b"), Op("&").Id("v")), Id("err").Op("!=").Nil()).Block(Return(Id("err"))), + Id("c").Dot("Text").Op("=").Op("&").Id("v"), + ), + Case(Lit("image")).Block( + Var().Id("v").Id("ImageContent"), + If(List(Id("err")).Op(":=").Qual("encoding/json", "Unmarshal").Call(Id("b"), Op("&").Id("v")), Id("err").Op("!=").Nil()).Block(Return(Id("err"))), + Id("c").Dot("Image").Op("=").Op("&").Id("v"), + ), + Case(Lit("audio")).Block( + Var().Id("v").Id("AudioContent"), + If(List(Id("err")).Op(":=").Qual("encoding/json", "Unmarshal").Call(Id("b"), Op("&").Id("v")), Id("err").Op("!=").Nil()).Block(Return(Id("err"))), + Id("c").Dot("Audio").Op("=").Op("&").Id("v"), + ), + Case(Lit("resource_link")).Block( + Var().Id("v").Id("ResourceLinkContent"), + If(List(Id("err")).Op(":=").Qual("encoding/json", "Unmarshal").Call(Id("b"), Op("&").Id("v")), Id("err").Op("!=").Nil()).Block(Return(Id("err"))), + Id("c").Dot("ResourceLink").Op("=").Op("&").Id("v"), + ), + Case(Lit("resource")).Block( + Var().Id("v").Id("EmbeddedResource"), + If(List(Id("err")).Op(":=").Qual("encoding/json", "Unmarshal").Call(Id("b"), Op("&").Id("v")), Id("err").Op("!=").Nil()).Block(Return(Id("err"))), + Id("c").Dot("Resource").Op("=").Op("&").Id("v"), + ), + ), + Return(Nil()), + ) + // MarshalJSON for ContentBlock + f.Func().Params(Id("c").Id("ContentBlock")).Id("MarshalJSON").Params().Params(Index().Byte(), Error()).Block( + Switch(Id("c").Dot("Type")).Block( + Case(Lit("text")).Block( + If(Id("c").Dot("Text").Op("!=").Nil()).Block( + Return(Qual("encoding/json", "Marshal").Call(Map(String()).Any().Values(Dict{ + Lit("type"): Lit("text"), + Lit("text"): Id("c").Dot("Text").Dot("Text"), + }))), + ), + ), + Case(Lit("image")).Block( + If(Id("c").Dot("Image").Op("!=").Nil()).Block( + Return(Qual("encoding/json", "Marshal").Call(Map(String()).Any().Values(Dict{ + Lit("type"): Lit("image"), + Lit("data"): Id("c").Dot("Image").Dot("Data"), + Lit("mimeType"): Id("c").Dot("Image").Dot("MimeType"), + Lit("uri"): Id("c").Dot("Image").Dot("Uri"), + }))), + ), + ), + Case(Lit("audio")).Block( + If(Id("c").Dot("Audio").Op("!=").Nil()).Block( + Return(Qual("encoding/json", "Marshal").Call(Map(String()).Any().Values(Dict{ + Lit("type"): Lit("audio"), + Lit("data"): Id("c").Dot("Audio").Dot("Data"), + Lit("mimeType"): Id("c").Dot("Audio").Dot("MimeType"), + }))), + ), + ), + Case(Lit("resource_link")).Block( + If(Id("c").Dot("ResourceLink").Op("!=").Nil()).Block( + Return(Qual("encoding/json", "Marshal").Call(Map(String()).Any().Values(Dict{ + Lit("type"): Lit("resource_link"), + Lit("name"): Id("c").Dot("ResourceLink").Dot("Name"), + Lit("uri"): Id("c").Dot("ResourceLink").Dot("Uri"), + Lit("description"): Id("c").Dot("ResourceLink").Dot("Description"), + Lit("mimeType"): Id("c").Dot("ResourceLink").Dot("MimeType"), + Lit("size"): Id("c").Dot("ResourceLink").Dot("Size"), + Lit("title"): Id("c").Dot("ResourceLink").Dot("Title"), + }))), + ), + ), + Case(Lit("resource")).Block( + If(Id("c").Dot("Resource").Op("!=").Nil()).Block( + Return(Qual("encoding/json", "Marshal").Call(Map(String()).Any().Values(Dict{ + Lit("type"): Lit("resource"), + Lit("resource"): Id("c").Dot("Resource").Dot("Resource"), + }))), + ), + ), + ), + Return(Index().Byte().Values(), Nil()), + ) + f.Line() +} + +func emitToolCallContentJen(f *File) { + // Helpers + f.Type().Id("DiffContent").Struct( + Id("NewText").String().Tag(map[string]string{"json": "newText"}), + Id("OldText").Op("*").String().Tag(map[string]string{"json": "oldText,omitempty"}), + Id("Path").String().Tag(map[string]string{"json": "path"}), + ) + f.Type().Id("TerminalRef").Struct(Id("TerminalId").String().Tag(map[string]string{"json": "terminalId"})) + f.Line() + // ToolCallContent + f.Type().Id("ToolCallContent").Struct( + Id("Type").String().Tag(map[string]string{"json": "type"}), + Id("Content").Op("*").Id("ContentBlock").Tag(map[string]string{"json": "-"}), + Id("Diff").Op("*").Id("DiffContent").Tag(map[string]string{"json": "-"}), + Id("Terminal").Op("*").Id("TerminalRef").Tag(map[string]string{"json": "-"}), + ) + f.Line() + // UnmarshalJSON + f.Func().Params(Id("t").Op("*").Id("ToolCallContent")).Id("UnmarshalJSON").Params(Id("b").Index().Byte()).Error().Block( + Var().Id("probe").Struct(Id("Type").String().Tag(map[string]string{"json": "type"})), + If(List(Id("err")).Op(":=").Qual("encoding/json", "Unmarshal").Call(Id("b"), Op("&").Id("probe")), Id("err").Op("!=").Nil()).Block(Return(Id("err"))), + Id("t").Dot("Type").Op("=").Id("probe").Dot("Type"), + Switch(Id("probe").Dot("Type")).Block( + Case(Lit("content")).Block( + Var().Id("v").Struct( + Id("Type").String().Tag(map[string]string{"json": "type"}), + Id("Content").Id("ContentBlock").Tag(map[string]string{"json": "content"}), + ), + If(List(Id("err")).Op(":=").Qual("encoding/json", "Unmarshal").Call(Id("b"), Op("&").Id("v")), Id("err").Op("!=").Nil()).Block(Return(Id("err"))), + Id("t").Dot("Content").Op("=").Op("&").Id("v").Dot("Content"), + ), + Case(Lit("diff")).Block( + Var().Id("v").Id("DiffContent"), + If(List(Id("err")).Op(":=").Qual("encoding/json", "Unmarshal").Call(Id("b"), Op("&").Id("v")), Id("err").Op("!=").Nil()).Block(Return(Id("err"))), + Id("t").Dot("Diff").Op("=").Op("&").Id("v"), + ), + Case(Lit("terminal")).Block( + Var().Id("v").Id("TerminalRef"), + If(List(Id("err")).Op(":=").Qual("encoding/json", "Unmarshal").Call(Id("b"), Op("&").Id("v")), Id("err").Op("!=").Nil()).Block(Return(Id("err"))), + Id("t").Dot("Terminal").Op("=").Op("&").Id("v"), + ), + ), + Return(Nil()), + ) + f.Line() +} + +func emitEmbeddedResourceResourceJen(f *File) { + // Holder with pointers to known variants + f.Type().Id("EmbeddedResourceResource").Struct( + Id("TextResourceContents").Op("*").Id("TextResourceContents").Tag(map[string]string{"json": "-"}), + Id("BlobResourceContents").Op("*").Id("BlobResourceContents").Tag(map[string]string{"json": "-"}), + ) + f.Line() + f.Func().Params(Id("e").Op("*").Id("EmbeddedResourceResource")).Id("UnmarshalJSON").Params(Id("b").Index().Byte()).Error().Block( + // Decide by presence of distinguishing keys + Var().Id("m").Map(String()).Qual("encoding/json", "RawMessage"), + If(List(Id("err")).Op(":=").Qual("encoding/json", "Unmarshal").Call(Id("b"), Op("&").Id("m")), Id("err").Op("!=").Nil()).Block(Return(Id("err"))), + // TextResourceContents has "text" key + If(List(Id("_"), Id("ok")).Op(":=").Id("m").Index(Lit("text")), Id("ok")).Block( + Var().Id("v").Id("TextResourceContents"), + If(List(Id("err")).Op(":=").Qual("encoding/json", "Unmarshal").Call(Id("b"), Op("&").Id("v")), Id("err").Op("!=").Nil()).Block(Return(Id("err"))), + Id("e").Dot("TextResourceContents").Op("=").Op("&").Id("v"), + Return(Nil()), + ), + // BlobResourceContents has "blob" key + If(List(Id("_"), Id("ok2")).Op(":=").Id("m").Index(Lit("blob")), Id("ok2")).Block( + Var().Id("v").Id("BlobResourceContents"), + If(List(Id("err")).Op(":=").Qual("encoding/json", "Unmarshal").Call(Id("b"), Op("&").Id("v")), Id("err").Op("!=").Nil()).Block(Return(Id("err"))), + Id("e").Dot("BlobResourceContents").Op("=").Op("&").Id("v"), + Return(Nil()), + ), + Return(Nil()), + ) + f.Line() +} + +func emitRequestPermissionOutcomeJen(f *File) { + // Variants + f.Type().Id("RequestPermissionOutcomeCancelled").Struct() + f.Type().Id("RequestPermissionOutcomeSelected").Struct( + Id("OptionId").Id("PermissionOptionId").Tag(map[string]string{"json": "optionId"}), + ) + f.Line() + // Holder + f.Type().Id("RequestPermissionOutcome").Struct( + Id("Cancelled").Op("*").Id("RequestPermissionOutcomeCancelled").Tag(map[string]string{"json": "-"}), + Id("Selected").Op("*").Id("RequestPermissionOutcomeSelected").Tag(map[string]string{"json": "-"}), + ) + f.Func().Params(Id("o").Op("*").Id("RequestPermissionOutcome")).Id("UnmarshalJSON").Params(Id("b").Index().Byte()).Error().Block( + Var().Id("m").Map(String()).Qual("encoding/json", "RawMessage"), + If(List(Id("err")).Op(":=").Qual("encoding/json", "Unmarshal").Call(Id("b"), Op("&").Id("m")), Id("err").Op("!=").Nil()).Block(Return(Id("err"))), + Var().Id("outcome").String(), + If(List(Id("v"), Id("ok")).Op(":=").Id("m").Index(Lit("outcome")), Id("ok")).Block( + Qual("encoding/json", "Unmarshal").Call(Id("v"), Op("&").Id("outcome")), + ), + Switch(Id("outcome")).Block( + Case(Lit("cancelled")).Block( + Id("o").Dot("Cancelled").Op("=").Op("&").Id("RequestPermissionOutcomeCancelled").Values(), + Return(Nil()), + ), + Case(Lit("selected")).Block( + Var().Id("v2").Id("RequestPermissionOutcomeSelected"), + If(List(Id("err")).Op(":=").Qual("encoding/json", "Unmarshal").Call(Id("b"), Op("&").Id("v2")), Id("err").Op("!=").Nil()).Block(Return(Id("err"))), + Id("o").Dot("Selected").Op("=").Op("&").Id("v2"), + Return(Nil()), + ), + ), + Return(Nil()), + ) + // MarshalJSON + f.Func().Params(Id("o").Id("RequestPermissionOutcome")).Id("MarshalJSON").Params().Params(Index().Byte(), Error()).Block( + If(Id("o").Dot("Cancelled").Op("!=").Nil()).Block( + Return(Qual("encoding/json", "Marshal").Call(Map(String()).Any().Values(Dict{Lit("outcome"): Lit("cancelled")}))), + ), + If(Id("o").Dot("Selected").Op("!=").Nil()).Block( + Return(Qual("encoding/json", "Marshal").Call(Map(String()).Any().Values(Dict{ + Lit("outcome"): Lit("selected"), + Lit("optionId"): Id("o").Dot("Selected").Dot("OptionId"), + }))), + ), + Return(Index().Byte().Values(), Nil()), + ) + f.Line() +} + +func emitSessionUpdateJen(f *File) { + // Variant types + f.Type().Id("SessionUpdateUserMessageChunk").Struct( + Id("Content").Id("ContentBlock").Tag(map[string]string{"json": "content"}), + ) + f.Type().Id("SessionUpdateAgentMessageChunk").Struct( + Id("Content").Id("ContentBlock").Tag(map[string]string{"json": "content"}), + ) + f.Type().Id("SessionUpdateAgentThoughtChunk").Struct( + Id("Content").Id("ContentBlock").Tag(map[string]string{"json": "content"}), + ) + f.Type().Id("SessionUpdateToolCall").Struct( + Id("Content").Index().Id("ToolCallContent").Tag(map[string]string{"json": "content,omitempty"}), + Id("Kind").Id("ToolKind").Tag(map[string]string{"json": "kind,omitempty"}), + Id("Locations").Index().Id("ToolCallLocation").Tag(map[string]string{"json": "locations,omitempty"}), + Id("RawInput").Any().Tag(map[string]string{"json": "rawInput,omitempty"}), + Id("RawOutput").Any().Tag(map[string]string{"json": "rawOutput,omitempty"}), + Id("Status").Id("ToolCallStatus").Tag(map[string]string{"json": "status,omitempty"}), + Id("Title").String().Tag(map[string]string{"json": "title"}), + Id("ToolCallId").Id("ToolCallId").Tag(map[string]string{"json": "toolCallId"}), + ) + f.Type().Id("SessionUpdateToolCallUpdate").Struct( + Id("Content").Index().Id("ToolCallContent").Tag(map[string]string{"json": "content,omitempty"}), + Id("Kind").Any().Tag(map[string]string{"json": "kind,omitempty"}), + Id("Locations").Index().Id("ToolCallLocation").Tag(map[string]string{"json": "locations,omitempty"}), + Id("RawInput").Any().Tag(map[string]string{"json": "rawInput,omitempty"}), + Id("RawOutput").Any().Tag(map[string]string{"json": "rawOutput,omitempty"}), + Id("Status").Any().Tag(map[string]string{"json": "status,omitempty"}), + Id("Title").Op("*").String().Tag(map[string]string{"json": "title,omitempty"}), + Id("ToolCallId").Id("ToolCallId").Tag(map[string]string{"json": "toolCallId"}), + ) + f.Type().Id("SessionUpdatePlan").Struct( + Id("Entries").Index().Id("PlanEntry").Tag(map[string]string{"json": "entries"}), + ) + f.Line() + // Holder + f.Type().Id("SessionUpdate").Struct( + Id("UserMessageChunk").Op("*").Id("SessionUpdateUserMessageChunk").Tag(map[string]string{"json": "-"}), + Id("AgentMessageChunk").Op("*").Id("SessionUpdateAgentMessageChunk").Tag(map[string]string{"json": "-"}), + Id("AgentThoughtChunk").Op("*").Id("SessionUpdateAgentThoughtChunk").Tag(map[string]string{"json": "-"}), + Id("ToolCall").Op("*").Id("SessionUpdateToolCall").Tag(map[string]string{"json": "-"}), + Id("ToolCallUpdate").Op("*").Id("SessionUpdateToolCallUpdate").Tag(map[string]string{"json": "-"}), + Id("Plan").Op("*").Id("SessionUpdatePlan").Tag(map[string]string{"json": "-"}), + ) + f.Func().Params(Id("s").Op("*").Id("SessionUpdate")).Id("UnmarshalJSON").Params(Id("b").Index().Byte()).Error().Block( + Var().Id("m").Map(String()).Qual("encoding/json", "RawMessage"), + If(List(Id("err")).Op(":=").Qual("encoding/json", "Unmarshal").Call(Id("b"), Op("&").Id("m")), Id("err").Op("!=").Nil()).Block(Return(Id("err"))), + Var().Id("kind").String(), + If(List(Id("v"), Id("ok")).Op(":=").Id("m").Index(Lit("sessionUpdate")), Id("ok")).Block( + Qual("encoding/json", "Unmarshal").Call(Id("v"), Op("&").Id("kind")), + ), + Switch(Id("kind")).Block( + Case(Lit("user_message_chunk")).Block( + Var().Id("v").Id("SessionUpdateUserMessageChunk"), + If(List(Id("err")).Op(":=").Qual("encoding/json", "Unmarshal").Call(Id("b"), Op("&").Id("v")), Id("err").Op("!=").Nil()).Block(Return(Id("err"))), + Id("s").Dot("UserMessageChunk").Op("=").Op("&").Id("v"), + Return(Nil()), + ), + Case(Lit("agent_message_chunk")).Block( + Var().Id("v").Id("SessionUpdateAgentMessageChunk"), + If(List(Id("err")).Op(":=").Qual("encoding/json", "Unmarshal").Call(Id("b"), Op("&").Id("v")), Id("err").Op("!=").Nil()).Block(Return(Id("err"))), + Id("s").Dot("AgentMessageChunk").Op("=").Op("&").Id("v"), + Return(Nil()), + ), + Case(Lit("agent_thought_chunk")).Block( + Var().Id("v").Id("SessionUpdateAgentThoughtChunk"), + If(List(Id("err")).Op(":=").Qual("encoding/json", "Unmarshal").Call(Id("b"), Op("&").Id("v")), Id("err").Op("!=").Nil()).Block(Return(Id("err"))), + Id("s").Dot("AgentThoughtChunk").Op("=").Op("&").Id("v"), + Return(Nil()), + ), + Case(Lit("tool_call")).Block( + Var().Id("v").Id("SessionUpdateToolCall"), + If(List(Id("err")).Op(":=").Qual("encoding/json", "Unmarshal").Call(Id("b"), Op("&").Id("v")), Id("err").Op("!=").Nil()).Block(Return(Id("err"))), + Id("s").Dot("ToolCall").Op("=").Op("&").Id("v"), + Return(Nil()), + ), + Case(Lit("tool_call_update")).Block( + Var().Id("v").Id("SessionUpdateToolCallUpdate"), + If(List(Id("err")).Op(":=").Qual("encoding/json", "Unmarshal").Call(Id("b"), Op("&").Id("v")), Id("err").Op("!=").Nil()).Block(Return(Id("err"))), + Id("s").Dot("ToolCallUpdate").Op("=").Op("&").Id("v"), + Return(Nil()), + ), + Case(Lit("plan")).Block( + Var().Id("v").Id("SessionUpdatePlan"), + If(List(Id("err")).Op(":=").Qual("encoding/json", "Unmarshal").Call(Id("b"), Op("&").Id("v")), Id("err").Op("!=").Nil()).Block(Return(Id("err"))), + Id("s").Dot("Plan").Op("=").Op("&").Id("v"), + Return(Nil()), + ), + ), + Return(Nil()), + ) + // MarshalJSON + f.Func().Params(Id("s").Id("SessionUpdate")).Id("MarshalJSON").Params().Params(Index().Byte(), Error()).Block( + If(Id("s").Dot("UserMessageChunk").Op("!=").Nil()).Block( + Return(Qual("encoding/json", "Marshal").Call(Map(String()).Any().Values(Dict{ + Lit("sessionUpdate"): Lit("user_message_chunk"), + Lit("content"): Id("s").Dot("UserMessageChunk").Dot("Content"), + }))), + ), + If(Id("s").Dot("AgentMessageChunk").Op("!=").Nil()).Block( + Return(Qual("encoding/json", "Marshal").Call(Map(String()).Any().Values(Dict{ + Lit("sessionUpdate"): Lit("agent_message_chunk"), + Lit("content"): Id("s").Dot("AgentMessageChunk").Dot("Content"), + }))), + ), + If(Id("s").Dot("AgentThoughtChunk").Op("!=").Nil()).Block( + Return(Qual("encoding/json", "Marshal").Call(Map(String()).Any().Values(Dict{ + Lit("sessionUpdate"): Lit("agent_thought_chunk"), + Lit("content"): Id("s").Dot("AgentThoughtChunk").Dot("Content"), + }))), + ), + If(Id("s").Dot("ToolCall").Op("!=").Nil()).Block( + Return(Qual("encoding/json", "Marshal").Call(Map(String()).Any().Values(Dict{ + Lit("sessionUpdate"): Lit("tool_call"), + Lit("content"): Id("s").Dot("ToolCall").Dot("Content"), + Lit("kind"): Id("s").Dot("ToolCall").Dot("Kind"), + Lit("locations"): Id("s").Dot("ToolCall").Dot("Locations"), + Lit("rawInput"): Id("s").Dot("ToolCall").Dot("RawInput"), + Lit("rawOutput"): Id("s").Dot("ToolCall").Dot("RawOutput"), + Lit("status"): Id("s").Dot("ToolCall").Dot("Status"), + Lit("title"): Id("s").Dot("ToolCall").Dot("Title"), + Lit("toolCallId"): Id("s").Dot("ToolCall").Dot("ToolCallId"), + }))), + ), + If(Id("s").Dot("ToolCallUpdate").Op("!=").Nil()).Block( + Return(Qual("encoding/json", "Marshal").Call(Map(String()).Any().Values(Dict{ + Lit("sessionUpdate"): Lit("tool_call_update"), + Lit("content"): Id("s").Dot("ToolCallUpdate").Dot("Content"), + Lit("kind"): Id("s").Dot("ToolCallUpdate").Dot("Kind"), + Lit("locations"): Id("s").Dot("ToolCallUpdate").Dot("Locations"), + Lit("rawInput"): Id("s").Dot("ToolCallUpdate").Dot("RawInput"), + Lit("rawOutput"): Id("s").Dot("ToolCallUpdate").Dot("RawOutput"), + Lit("status"): Id("s").Dot("ToolCallUpdate").Dot("Status"), + Lit("title"): Id("s").Dot("ToolCallUpdate").Dot("Title"), + Lit("toolCallId"): Id("s").Dot("ToolCallUpdate").Dot("ToolCallId"), + }))), + ), + If(Id("s").Dot("Plan").Op("!=").Nil()).Block( + Return(Qual("encoding/json", "Marshal").Call(Map(String()).Any().Values(Dict{ + Lit("sessionUpdate"): Lit("plan"), + Lit("entries"): Id("s").Dot("Plan").Dot("Entries"), + }))), + ), + Return(Index().Byte().Values(), Nil()), + ) + f.Line() +} + +func primitiveJenType(t string) Code { + switch t { + case "string": + return String() + case "integer": + return Int() + case "number": + return Float64() + case "boolean": + return Bool() + default: + return Any() + } +} + +func jenTypeFor(d *Definition) Code { + if d == nil { + return Any() + } + if d.Ref != "" { + if strings.HasPrefix(d.Ref, "#/$defs/") { + return Id(d.Ref[len("#/$defs/"):]) + } + return Any() + } + if len(d.Enum) > 0 { + return String() + } + switch primaryType(d) { + case "string": + return String() + case "integer": + return Int() + case "number": + return Float64() + case "boolean": + return Bool() + case "array": + return Index().Add(jenTypeFor(d.Items)) + case "object": + if len(d.Properties) == 0 { + return Map(String()).Any() + } + return Map(String()).Any() + default: + if len(d.AnyOf) > 0 || len(d.OneOf) > 0 { + return Any() + } + return Any() + } +} + +func isNullResponse(def *Definition) bool { + if def == nil { + return true + } + // type: null or oneOf with const null (unlikely here) + if s, ok := def.Type.(string); ok && s == "null" { + return true + } + return false +} + +func dispatchMethodNameForNotification(methodKey, typeName string) string { + switch methodKey { + case "session_update": + return "SessionUpdate" + case "session_cancel": + return "Cancel" + default: + // Fallback to type base without suffix + if strings.HasSuffix(typeName, "Notification") { + return strings.TrimSuffix(typeName, "Notification") + } + return typeName + } +} + +func writeDispatchJen(outDir string, schema *Schema, meta *Meta) error { + // Build method groups + type methodInfo struct { + Side, Method string + Req, Resp, Notif string + } + groups := map[string]*methodInfo{} + for name, def := range schema.Defs { + if def == nil || def.XMethod == "" || def.XSide == "" { + continue + } + key := def.XSide + "|" + def.XMethod + mi := groups[key] + if mi == nil { + mi = &methodInfo{Side: def.XSide, Method: def.XMethod} + groups[key] = mi + } + if strings.HasSuffix(name, "Request") { + mi.Req = name + } + if strings.HasSuffix(name, "Response") { + mi.Resp = name + } + if strings.HasSuffix(name, "Notification") { + mi.Notif = name + } + } + + // Agent handler method + fAgent := NewFile("acp") + fAgent.HeaderComment("Code generated by acp-go-generator; DO NOT EDIT.") + // func (a *AgentSideConnection) handle(method string, params json.RawMessage) (any, *RequestError) { switch method { ... } } + switchCases := []Code{} + // deterministic order via meta.AgentMethods + amKeys := make([]string, 0, len(meta.AgentMethods)) + for k := range meta.AgentMethods { + amKeys = append(amKeys, k) + } + sort.Strings(amKeys) + for _, k := range amKeys { + wire := meta.AgentMethods[k] + mi := groups["agent|"+wire] + if mi == nil { + continue + } + caseBody := []Code{} + if mi.Notif != "" { + // var p T; if err := json.Unmarshal(params, &p); err != nil { return nil, NewInvalidParams(...) } + caseBody = append(caseBody, + Var().Id("p").Id(mi.Notif), + If( + List(Id("err")).Op(":=").Qual("encoding/json", "Unmarshal").Call(Id("params"), Op("&").Id("p")), + Id("err").Op("!=").Nil(), + ).Block( + Return(Nil(), Id("NewInvalidParams").Call(Map(String()).Id("any").Values(Dict{Lit("error"): Id("err").Dot("Error").Call()}))), + ), + ) + // if err := a.agent.Call(p); err != nil { return nil, toReqErr(err) }; return nil, nil + callName := dispatchMethodNameForNotification(k, mi.Notif) + caseBody = append(caseBody, + If( + List(Id("err")).Op(":=").Id("a").Dot("agent").Dot(callName).Call(Id("p")), + Id("err").Op("!=").Nil(), + ).Block( + Return(Nil(), Id("toReqErr").Call(Id("err"))), + ), + Return(Nil(), Nil()), + ) + } else if mi.Req != "" { + respName := strings.TrimSuffix(mi.Req, "Request") + "Response" + caseBody = append(caseBody, + Var().Id("p").Id(mi.Req), + If( + List(Id("err")).Op(":=").Qual("encoding/json", "Unmarshal").Call(Id("params"), Op("&").Id("p")), + Id("err").Op("!=").Nil(), + ).Block( + Return(Nil(), Id("NewInvalidParams").Call(Map(String()).Id("any").Values(Dict{Lit("error"): Id("err").Dot("Error").Call()}))), + ), + ) + methodName := strings.TrimSuffix(mi.Req, "Request") + if isNullResponse(schema.Defs[respName]) { + caseBody = append(caseBody, + If( + List(Id("err")).Op(":=").Id("a").Dot("agent").Dot(methodName).Call(Id("p")), + Id("err").Op("!=").Nil(), + ).Block( + Return(Nil(), Id("toReqErr").Call(Id("err"))), + ), + Return(Nil(), Nil()), + ) + } else { + caseBody = append(caseBody, + List(Id("resp"), Id("err")).Op(":=").Id("a").Dot("agent").Dot(methodName).Call(Id("p")), + If(Id("err").Op("!=").Nil()).Block(Return(Nil(), Id("toReqErr").Call(Id("err")))), + Return(Id("resp"), Nil()), + ) + } + } + if len(caseBody) > 0 { + switchCases = append(switchCases, Case(Id("AgentMethod"+toExportedConst(k))).Block(caseBody...)) + } + } + switchCases = append(switchCases, Default().Block(Return(Nil(), Id("NewMethodNotFound").Call(Id("method"))))) + fAgent.Func().Params(Id("a").Op("*").Id("AgentSideConnection")).Id("handle").Params( + Id("method").String(), + Id("params").Qual("encoding/json", "RawMessage"), + ).Params(Any(), Op("*").Id("RequestError")).Block( + Switch(Id("method")).Block(switchCases...), + ) + // After generating the handler, also append outbound wrappers for AgentSideConnection + // Build const name reverse lookup + agentConst := map[string]string{} + for k, v := range meta.AgentMethods { + agentConst[v] = "AgentMethod" + toExportedConst(k) + } + clientConst := map[string]string{} + for k, v := range meta.ClientMethods { + clientConst[v] = "ClientMethod" + toExportedConst(k) + } + // Agent outbound: methods the agent can call on the client + for _, mi := range groups { + if mi.Side != "client" { + continue + } + constName := clientConst[mi.Method] + if constName == "" { + continue + } + if mi.Notif != "" { + name := strings.TrimSuffix(mi.Notif, "Notification") + switch mi.Method { + case "session/update": + name = "SessionUpdate" + case "session/cancel": + name = "Cancel" + } + fAgent.Func().Params(Id("c").Op("*").Id("AgentSideConnection")).Id(name).Params(Id("params").Id(mi.Notif)).Error(). + Block(Return(Id("c").Dot("conn").Dot("SendNotification").Call(Id(constName), Id("params")))) + } else if mi.Req != "" { + respName := strings.TrimSuffix(mi.Req, "Request") + "Response" + if isNullResponse(schema.Defs[respName]) { + fAgent.Func().Params(Id("c").Op("*").Id("AgentSideConnection")).Id(strings.TrimSuffix(mi.Req, "Request")). + Params(Id("params").Id(mi.Req)).Error(). + Block(Return(Id("c").Dot("conn").Dot("SendRequestNoResult").Call(Id(constName), Id("params")))) + } else { + fAgent.Func().Params(Id("c").Op("*").Id("AgentSideConnection")).Id(strings.TrimSuffix(mi.Req, "Request")). + Params(Id("params").Id(mi.Req)).Params(Id(respName), Error()). + Block( + List(Id("resp"), Id("err")).Op(":=").Id("SendRequest").Types(Id(respName)).Call(Id("c").Dot("conn"), Id(constName), Id("params")), + Return(Id("resp"), Id("err")), + ) + } + } + } + var bufA bytes.Buffer + if err := fAgent.Render(&bufA); err != nil { + return err + } + if err := os.WriteFile(filepath.Join(outDir, "agent_gen.go"), bufA.Bytes(), 0o644); err != nil { + return err + } + + // Client handler method + fClient := NewFile("acp") + fClient.HeaderComment("Code generated by acp-go-generator; DO NOT EDIT.") + cCases := []Code{} + cmKeys := make([]string, 0, len(meta.ClientMethods)) + for k := range meta.ClientMethods { + cmKeys = append(cmKeys, k) + } + sort.Strings(cmKeys) + for _, k := range cmKeys { + wire := meta.ClientMethods[k] + mi := groups["client|"+wire] + if mi == nil { + continue + } + body := []Code{} + if mi.Notif != "" { + body = append(body, + Var().Id("p").Id(mi.Notif), + If(List(Id("err")).Op(":=").Qual("encoding/json", "Unmarshal").Call(Id("params"), Op("&").Id("p")), Id("err").Op("!=").Nil()).Block( + Return(Nil(), Id("NewInvalidParams").Call(Map(String()).Id("any").Values(Dict{Lit("error"): Id("err").Dot("Error").Call()}))), + ), + ) + callName := dispatchMethodNameForNotification(k, mi.Notif) + body = append(body, + If(List(Id("err")).Op(":=").Id("c").Dot("client").Dot(callName).Call(Id("p")), Id("err").Op("!=").Nil()).Block( + Return(Nil(), Id("toReqErr").Call(Id("err"))), + ), + Return(Nil(), Nil()), + ) + } else if mi.Req != "" { + respName := strings.TrimSuffix(mi.Req, "Request") + "Response" + body = append(body, + Var().Id("p").Id(mi.Req), + If(List(Id("err")).Op(":=").Qual("encoding/json", "Unmarshal").Call(Id("params"), Op("&").Id("p")), Id("err").Op("!=").Nil()).Block( + Return(Nil(), Id("NewInvalidParams").Call(Map(String()).Id("any").Values(Dict{Lit("error"): Id("err").Dot("Error").Call()}))), + ), + ) + methodName := strings.TrimSuffix(mi.Req, "Request") + if isNullResponse(schema.Defs[respName]) { + body = append(body, + If(List(Id("err")).Op(":=").Id("c").Dot("client").Dot(methodName).Call(Id("p")), Id("err").Op("!=").Nil()).Block( + Return(Nil(), Id("toReqErr").Call(Id("err"))), + ), + Return(Nil(), Nil()), + ) + } else { + body = append(body, + List(Id("resp"), Id("err")).Op(":=").Id("c").Dot("client").Dot(methodName).Call(Id("p")), + If(Id("err").Op("!=").Nil()).Block(Return(Nil(), Id("toReqErr").Call(Id("err")))), + Return(Id("resp"), Nil()), + ) + } + } + if len(body) > 0 { + cCases = append(cCases, Case(Id("ClientMethod"+toExportedConst(k))).Block(body...)) + } + } + cCases = append(cCases, Default().Block(Return(Nil(), Id("NewMethodNotFound").Call(Id("method"))))) + fClient.Func().Params(Id("c").Op("*").Id("ClientSideConnection")).Id("handle").Params( + Id("method").String(), Id("params").Qual("encoding/json", "RawMessage")).Params( + Any(), Op("*").Id("RequestError")).Block( + Switch(Id("method")).Block(cCases...), + ) + // After generating the handler, also append outbound wrappers for ClientSideConnection + for _, mi := range groups { + if mi.Side != "agent" { + continue + } + constName := agentConst[mi.Method] + if constName == "" { + continue + } + if mi.Notif != "" { + name := strings.TrimSuffix(mi.Notif, "Notification") + switch mi.Method { + case "session/update": + name = "SessionUpdate" + case "session/cancel": + name = "Cancel" + } + fClient.Func().Params(Id("c").Op("*").Id("ClientSideConnection")).Id(name).Params(Id("params").Id(mi.Notif)).Error(). + Block(Return(Id("c").Dot("conn").Dot("SendNotification").Call(Id(constName), Id("params")))) + } else if mi.Req != "" { + respName := strings.TrimSuffix(mi.Req, "Request") + "Response" + if isNullResponse(schema.Defs[respName]) { + fClient.Func().Params(Id("c").Op("*").Id("ClientSideConnection")).Id(strings.TrimSuffix(mi.Req, "Request")). + Params(Id("params").Id(mi.Req)).Error(). + Block(Return(Id("c").Dot("conn").Dot("SendRequestNoResult").Call(Id(constName), Id("params")))) + } else { + fClient.Func().Params(Id("c").Op("*").Id("ClientSideConnection")).Id(strings.TrimSuffix(mi.Req, "Request")). + Params(Id("params").Id(mi.Req)).Params(Id(respName), Error()). + Block( + List(Id("resp"), Id("err")).Op(":=").Id("SendRequest").Types(Id(respName)).Call(Id("c").Dot("conn"), Id(constName), Id("params")), + Return(Id("resp"), Id("err")), + ) + } + } + } + var bufC bytes.Buffer + if err := fClient.Render(&bufC); err != nil { + return err + } + if err := os.WriteFile(filepath.Join(outDir, "client_gen.go"), bufC.Bytes(), 0o644); err != nil { + return err + } + + // Clean up old split outbound files if present + _ = os.Remove(filepath.Join(outDir, "agent_outbound_gen.go")) + _ = os.Remove(filepath.Join(outDir, "client_outbound_gen.go")) + return nil +} + +func sanitizeComment(s string) string { + // Remove backticks and normalize newlines + s = strings.ReplaceAll(s, "`", "'") + lines := strings.Split(s, "\n") + for i := range lines { + lines[i] = strings.TrimSpace(lines[i]) + } + return strings.Join(lines, " ") +} + +func primaryType(d *Definition) string { + if d == nil || d.Type == nil { + return "" + } + switch v := d.Type.(type) { + case string: + return v + case []any: + // choose a non-null type if present + var first string + for _, e := range v { + if s, ok := e.(string); ok { + if first == "" { + first = s + } + if s != "null" { + return s + } + } + } + return first + default: + return "" + } +} + +func toExportedField(name string) string { + // Convert camelCase or snake_case to PascalCase; keep common acronyms minimal (ID -> Id) + // First, split on underscores + parts := strings.Split(name, "_") + if len(parts) == 1 { + // handle camelCase + parts = splitCamel(name) + } + for i := range parts { + parts[i] = titleWord(parts[i]) + } + return strings.Join(parts, "") +} + +func splitCamel(s string) []string { + var parts []string + last := 0 + for i := 1; i < len(s); i++ { + if isBoundary(s[i-1], s[i]) { + parts = append(parts, s[last:i]) + last = i + } + } + parts = append(parts, s[last:]) + return parts +} + +func isBoundary(prev, curr byte) bool { + return (prev >= 'a' && prev <= 'z' && curr >= 'A' && curr <= 'Z') || curr == '_' +} + +func toEnumConst(typeName, val string) string { + // Build CONST like + // Normalize value: replace non-alnum with underscores, split by underscores or spaces, title-case. + cleaned := make([]rune, 0, len(val)) + for _, r := range val { + if (r >= 'a' && r <= 'z') || (r >= 'A' && r <= 'Z') || (r >= '0' && r <= '9') { + cleaned = append(cleaned, r) + } else { + cleaned = append(cleaned, '_') + } + } + parts := strings.FieldsFunc(string(cleaned), func(r rune) bool { return r == '_' }) + for i := range parts { + parts[i] = titleWord(strings.ToLower(parts[i])) + } + return typeName + strings.Join(parts, "") +} + +func titleWord(s string) string { + if s == "" { + return s + } + r := []rune(s) + r[0] = unicode.ToUpper(r[0]) + for i := 1; i < len(r); i++ { + r[i] = unicode.ToLower(r[i]) + } + return string(r) +} diff --git a/go/connection.go b/go/connection.go new file mode 100644 index 0000000..f6d7e66 --- /dev/null +++ b/go/connection.go @@ -0,0 +1,265 @@ +package acp + +import ( + "bufio" + "encoding/json" + "io" + "sync" + "sync/atomic" +) + +type anyMessage struct { + JSONRPC string `json:"jsonrpc"` + ID *json.RawMessage `json:"id,omitempty"` + Method string `json:"method,omitempty"` + Params json.RawMessage `json:"params,omitempty"` + Result json.RawMessage `json:"result,omitempty"` + Error *RequestError `json:"error,omitempty"` +} + +type pendingResponse struct { + ch chan anyMessage +} + +type MethodHandler func(method string, params json.RawMessage) (any, *RequestError) + +// Connection is a simple JSON-RPC 2.0 connection over line-delimited JSON. +type Connection struct { + w io.Writer + r io.Reader + handler MethodHandler + + mu sync.Mutex + nextID atomic.Uint64 + pending map[string]*pendingResponse + + done chan struct{} +} + +func NewConnection(handler MethodHandler, peerInput io.Writer, peerOutput io.Reader) *Connection { + c := &Connection{ + w: peerInput, + r: peerOutput, + handler: handler, + pending: make(map[string]*pendingResponse), + done: make(chan struct{}), + } + go c.receive() + return c +} + +func (c *Connection) receive() { + scanner := bufio.NewScanner(c.r) + // increase buffer if needed + buf := make([]byte, 0, 1024*1024) + scanner.Buffer(buf, 10*1024*1024) + for scanner.Scan() { + line := scanner.Bytes() + if len(bytesTrimSpace(line)) == 0 { + continue + } + var msg anyMessage + if err := json.Unmarshal(line, &msg); err != nil { + // ignore parse errors on inbound + continue + } + if msg.ID != nil && msg.Method == "" { + // response + idStr := string(*msg.ID) + c.mu.Lock() + pr := c.pending[idStr] + if pr != nil { + delete(c.pending, idStr) + } + c.mu.Unlock() + if pr != nil { + pr.ch <- msg + } + continue + } + if msg.Method != "" { + // request or notification + go c.handleInbound(&msg) + } + } + // Signal completion on EOF or read error + c.mu.Lock() + if c.done != nil { + close(c.done) + c.done = nil + } + c.mu.Unlock() +} + +func (c *Connection) handleInbound(req *anyMessage) { + res := anyMessage{JSONRPC: "2.0"} + // copy ID if present + if req.ID != nil { + res.ID = req.ID + } + if c.handler == nil { + if req.ID != nil { + res.Error = NewMethodNotFound(req.Method) + _ = c.sendMessage(res) + } + return + } + + result, err := c.handler(req.Method, req.Params) + if req.ID == nil { + // notification: nothing to send + return + } + if err != nil { + res.Error = err + } else { + // marshal result + b, mErr := json.Marshal(result) + if mErr != nil { + res.Error = NewInternalError(map[string]any{"error": mErr.Error()}) + } else { + res.Result = b + } + } + _ = c.sendMessage(res) +} + +func (c *Connection) sendMessage(msg anyMessage) error { + c.mu.Lock() + defer c.mu.Unlock() + msg.JSONRPC = "2.0" + b, err := json.Marshal(msg) + if err != nil { + return err + } + b = append(b, '\n') + _, err = c.w.Write(b) + return err +} + +// SendRequest sends a JSON-RPC request and returns a typed result. +// For methods that do not return a result, use SendRequestNoResult instead. +func SendRequest[T any](c *Connection, method string, params any) (T, error) { + var zero T + // allocate id + id := c.nextID.Add(1) + idRaw, _ := json.Marshal(id) + msg := anyMessage{ + JSONRPC: "2.0", + ID: (*json.RawMessage)(&idRaw), + Method: method, + } + if params != nil { + b, err := json.Marshal(params) + if err != nil { + return zero, NewInvalidParams(map[string]any{"error": err.Error()}) + } + msg.Params = b + } + pr := &pendingResponse{ch: make(chan anyMessage, 1)} + idKey := string(idRaw) + c.mu.Lock() + c.pending[idKey] = pr + c.mu.Unlock() + if err := c.sendMessage(msg); err != nil { + return zero, NewInternalError(map[string]any{"error": err.Error()}) + } + // wait for response or peer disconnect + var resp anyMessage + d := c.Done() + select { + case resp = <-pr.ch: + case <-d: + return zero, NewInternalError(map[string]any{"error": "peer disconnected before response"}) + } + if resp.Error != nil { + return zero, resp.Error + } + var out T + if len(resp.Result) > 0 { + if err := json.Unmarshal(resp.Result, &out); err != nil { + return zero, NewInternalError(map[string]any{"error": err.Error()}) + } + } + return out, nil +} + +// SendRequestNoResult sends a JSON-RPC request that returns no result payload. +func (c *Connection) SendRequestNoResult(method string, params any) error { + // allocate id + id := c.nextID.Add(1) + idRaw, _ := json.Marshal(id) + msg := anyMessage{ + JSONRPC: "2.0", + ID: (*json.RawMessage)(&idRaw), + Method: method, + } + if params != nil { + b, err := json.Marshal(params) + if err != nil { + return NewInvalidParams(map[string]any{"error": err.Error()}) + } + msg.Params = b + } + pr := &pendingResponse{ch: make(chan anyMessage, 1)} + idKey := string(idRaw) + c.mu.Lock() + c.pending[idKey] = pr + c.mu.Unlock() + if err := c.sendMessage(msg); err != nil { + return NewInternalError(map[string]any{"error": err.Error()}) + } + var resp anyMessage + d := c.Done() + select { + case resp = <-pr.ch: + case <-d: + return NewInternalError(map[string]any{"error": "peer disconnected before response"}) + } + if resp.Error != nil { + return resp.Error + } + return nil +} + +func (c *Connection) SendNotification(method string, params any) error { + msg := anyMessage{JSONRPC: "2.0", Method: method} + if params != nil { + b, err := json.Marshal(params) + if err != nil { + return NewInvalidParams(map[string]any{"error": err.Error()}) + } + msg.Params = b + } + if err := c.sendMessage(msg); err != nil { + return NewInternalError(map[string]any{"error": err.Error()}) + } + return nil +} + +// Done returns a channel that is closed when the underlying reader loop exits +// (typically when the peer disconnects or the input stream is closed). +func (c *Connection) Done() <-chan struct{} { + c.mu.Lock() + d := c.done + c.mu.Unlock() + return d +} + +// Helper: lightweight TrimSpace for []byte without importing bytes only for this. +func bytesTrimSpace(b []byte) []byte { + i := 0 + for ; i < len(b); i++ { + if b[i] != ' ' && b[i] != '\t' && b[i] != '\r' && b[i] != '\n' { + break + } + } + j := len(b) + for j > i { + if b[j-1] != ' ' && b[j-1] != '\t' && b[j-1] != '\r' && b[j-1] != '\n' { + break + } + j-- + } + return b[i:j] +} diff --git a/go/constants.go b/go/constants.go new file mode 100644 index 0000000..9e96b6e --- /dev/null +++ b/go/constants.go @@ -0,0 +1,28 @@ +// Code generated by acp-go-generator; DO NOT EDIT. + +package acp + +// ProtocolVersionNumber is the ACP protocol version supported by this SDK. +const ProtocolVersionNumber = 1 + +// Agent method names +const ( + AgentMethodAuthenticate = "authenticate" + AgentMethodInitialize = "initialize" + AgentMethodSessionCancel = "session/cancel" + AgentMethodSessionLoad = "session/load" + AgentMethodSessionNew = "session/new" + AgentMethodSessionPrompt = "session/prompt" +) + +// Client method names +const ( + ClientMethodFsReadTextFile = "fs/read_text_file" + ClientMethodFsWriteTextFile = "fs/write_text_file" + ClientMethodSessionRequestPermission = "session/request_permission" + ClientMethodSessionUpdate = "session/update" + ClientMethodTerminalCreate = "terminal/create" + ClientMethodTerminalOutput = "terminal/output" + ClientMethodTerminalRelease = "terminal/release" + ClientMethodTerminalWaitForExit = "terminal/wait_for_exit" +) diff --git a/go/doc.go b/go/doc.go new file mode 100644 index 0000000..f512d82 --- /dev/null +++ b/go/doc.go @@ -0,0 +1,5 @@ +// Package acp provides Go types and connection plumbing for the +// Agent Client Protocol (ACP). It contains generated dispatchers, +// outbound helpers, shared request/response types, and related +// utilities used by agents and clients to communicate over ACP. +package acp diff --git a/go/errors.go b/go/errors.go new file mode 100644 index 0000000..e9263a0 --- /dev/null +++ b/go/errors.go @@ -0,0 +1,45 @@ +package acp + +// RequestError represents a JSON-RPC error response. +type RequestError struct { + Code int `json:"code"` + Message string `json:"message"` + Data any `json:"data,omitempty"` +} + +func (e *RequestError) Error() string { return e.Message } + +func NewParseError(data any) *RequestError { + return &RequestError{Code: -32700, Message: "Parse error", Data: data} +} + +func NewInvalidRequest(data any) *RequestError { + return &RequestError{Code: -32600, Message: "Invalid request", Data: data} +} + +func NewMethodNotFound(method string) *RequestError { + return &RequestError{Code: -32601, Message: "Method not found", Data: map[string]any{"method": method}} +} + +func NewInvalidParams(data any) *RequestError { + return &RequestError{Code: -32602, Message: "Invalid params", Data: data} +} + +func NewInternalError(data any) *RequestError { + return &RequestError{Code: -32603, Message: "Internal error", Data: data} +} + +func NewAuthRequired(data any) *RequestError { + return &RequestError{Code: -32000, Message: "Authentication required", Data: data} +} + +// toReqErr coerces arbitrary errors into JSON-RPC RequestError. +func toReqErr(err error) *RequestError { + if err == nil { + return nil + } + if re, ok := err.(*RequestError); ok { + return re + } + return NewInternalError(map[string]any{"error": err.Error()}) +} diff --git a/go/example/agent/main.go b/go/example/agent/main.go new file mode 100644 index 0000000..ad1fd8a --- /dev/null +++ b/go/example/agent/main.go @@ -0,0 +1,267 @@ +package main + +import ( + "context" + "crypto/rand" + "encoding/hex" + "fmt" + "io" + "os" + "time" + + acp "github.com/zed-industries/agent-client-protocol/go" +) + +type agentSession struct { + cancel context.CancelFunc +} + +type exampleAgent struct { + conn *acp.AgentSideConnection + sessions map[string]*agentSession +} + +var _ acp.Agent = (*exampleAgent)(nil) + +func newExampleAgent() *exampleAgent { + return &exampleAgent{sessions: make(map[string]*agentSession)} +} + +// Implement acp.AgentConnAware to receive the connection after construction. +func (a *exampleAgent) SetAgentConnection(conn *acp.AgentSideConnection) { a.conn = conn } + +func (a *exampleAgent) Initialize(params acp.InitializeRequest) (acp.InitializeResponse, error) { + return acp.InitializeResponse{ + ProtocolVersion: acp.ProtocolVersionNumber, + AgentCapabilities: acp.AgentCapabilities{ + LoadSession: false, + }, + }, nil +} + +func (a *exampleAgent) NewSession(params acp.NewSessionRequest) (acp.NewSessionResponse, error) { + sid := randomID() + a.sessions[sid] = &agentSession{} + return acp.NewSessionResponse{SessionId: acp.SessionId(sid)}, nil +} + +func (a *exampleAgent) Authenticate(_ acp.AuthenticateRequest) error { return nil } + +func (a *exampleAgent) LoadSession(_ acp.LoadSessionRequest) error { return nil } + +func (a *exampleAgent) Cancel(params acp.CancelNotification) error { + if s, ok := a.sessions[string(params.SessionId)]; ok { + if s.cancel != nil { + s.cancel() + } + } + return nil +} + +func (a *exampleAgent) Prompt(params acp.PromptRequest) (acp.PromptResponse, error) { + sid := string(params.SessionId) + s, ok := a.sessions[sid] + if !ok { + return acp.PromptResponse{}, fmt.Errorf("session %s not found", sid) + } + + // cancel any previous turn + if s.cancel != nil { + s.cancel() + } + ctx, cancel := context.WithCancel(context.Background()) + s.cancel = cancel + + // simulate a full turn with streaming updates and a permission request + if err := a.simulateTurn(ctx, sid); err != nil { + if ctx.Err() != nil { + return acp.PromptResponse{StopReason: acp.StopReasonCancelled}, nil + } + return acp.PromptResponse{}, err + } + s.cancel = nil + return acp.PromptResponse{StopReason: acp.StopReasonEndTurn}, nil +} + +func (a *exampleAgent) simulateTurn(ctx context.Context, sid string) error { + // initial message chunk + if err := a.conn.SessionUpdate(acp.SessionNotification{ + SessionId: acp.SessionId(sid), + Update: acp.SessionUpdate{ + AgentMessageChunk: &acp.SessionUpdateAgentMessageChunk{Content: acp.ContentBlock{ + Type: "text", + Text: &acp.TextContent{Text: "I'll help you with that. Let me start by reading some files to understand the current situation."}, + }}, + }, + }); err != nil { + return err + } + if err := pause(ctx, time.Second); err != nil { + return err + } + + // tool call without permission + if err := a.conn.SessionUpdate(acp.SessionNotification{ + SessionId: acp.SessionId(sid), + Update: acp.SessionUpdate{ToolCall: &acp.SessionUpdateToolCall{ + ToolCallId: acp.ToolCallId("call_1"), + Title: "Reading project files", + Kind: acp.ToolKindRead, + Status: acp.ToolCallStatusPending, + Locations: []acp.ToolCallLocation{{Path: "/project/README.md"}}, + RawInput: map[string]any{"path": "/project/README.md"}, + }}, + }); err != nil { + return err + } + if err := pause(ctx, time.Second); err != nil { + return err + } + + // update tool call completed + if err := a.conn.SessionUpdate(acp.SessionNotification{ + SessionId: acp.SessionId(sid), + Update: acp.SessionUpdate{ToolCallUpdate: &acp.SessionUpdateToolCallUpdate{ + ToolCallId: acp.ToolCallId("call_1"), + Status: "completed", + Content: []acp.ToolCallContent{{ + Type: "content", + Content: &acp.ContentBlock{Type: "text", Text: &acp.TextContent{Text: "# My Project\n\nThis is a sample project..."}}, + }}, + RawOutput: map[string]any{"content": "# My Project\n\nThis is a sample project..."}, + }}, + }); err != nil { + return err + } + if err := pause(ctx, time.Second); err != nil { + return err + } + + // more text + if err := a.conn.SessionUpdate(acp.SessionNotification{ + SessionId: acp.SessionId(sid), + Update: acp.SessionUpdate{AgentMessageChunk: &acp.SessionUpdateAgentMessageChunk{Content: acp.ContentBlock{ + Type: "text", + Text: &acp.TextContent{Text: " Now I understand the project structure. I need to make some changes to improve it."}, + }}}, + }); err != nil { + return err + } + if err := pause(ctx, time.Second); err != nil { + return err + } + + // tool call requiring permission + if err := a.conn.SessionUpdate(acp.SessionNotification{ + SessionId: acp.SessionId(sid), + Update: acp.SessionUpdate{ToolCall: &acp.SessionUpdateToolCall{ + ToolCallId: acp.ToolCallId("call_2"), + Title: "Modifying critical configuration file", + Kind: acp.ToolKindEdit, + Status: acp.ToolCallStatusPending, + Locations: []acp.ToolCallLocation{{Path: "/project/config.json"}}, + RawInput: map[string]any{"path": "/project/config.json", "content": "{\"database\": {\"host\": \"new-host\"}}"}, + }}, + }); err != nil { + return err + } + + // request permission for sensitive operation + permResp, err := a.conn.RequestPermission(acp.RequestPermissionRequest{ + SessionId: acp.SessionId(sid), + ToolCall: acp.ToolCallUpdate{ + ToolCallId: acp.ToolCallId("call_2"), + Title: "Modifying critical configuration file", + Kind: "edit", + Status: "pending", + Locations: []acp.ToolCallLocation{{Path: "/home/user/project/config.json"}}, + RawInput: map[string]any{"path": "/home/user/project/config.json", "content": "{\"database\": {\"host\": \"new-host\"}}"}, + }, + Options: []acp.PermissionOption{ + {Kind: acp.PermissionOptionKindAllowOnce, Name: "Allow this change", OptionId: acp.PermissionOptionId("allow")}, + {Kind: acp.PermissionOptionKindRejectOnce, Name: "Skip this change", OptionId: acp.PermissionOptionId("reject")}, + }, + }) + if err != nil { + return err + } + + // handle permission outcome + if permResp.Outcome.Cancelled != nil { + return nil + } + if permResp.Outcome.Selected == nil { + return fmt.Errorf("unexpected permission outcome") + } + switch string(permResp.Outcome.Selected.OptionId) { + case "allow": + if err := a.conn.SessionUpdate(acp.SessionNotification{ + SessionId: acp.SessionId(sid), + Update: acp.SessionUpdate{ToolCallUpdate: &acp.SessionUpdateToolCallUpdate{ + ToolCallId: acp.ToolCallId("call_2"), + Status: "completed", + RawOutput: map[string]any{"success": true, "message": "Configuration updated"}, + }}, + }); err != nil { + return err + } + if err := pause(ctx, time.Second); err != nil { + return err + } + if err := a.conn.SessionUpdate(acp.SessionNotification{ + SessionId: acp.SessionId(sid), + Update: acp.SessionUpdate{AgentMessageChunk: &acp.SessionUpdateAgentMessageChunk{Content: acp.ContentBlock{ + Type: "text", + Text: &acp.TextContent{Text: " Perfect! I've successfully updated the configuration. The changes have been applied."}, + }}}, + }); err != nil { + return err + } + case "reject": + if err := pause(ctx, time.Second); err != nil { + return err + } + if err := a.conn.SessionUpdate(acp.SessionNotification{ + SessionId: acp.SessionId(sid), + Update: acp.SessionUpdate{AgentMessageChunk: &acp.SessionUpdateAgentMessageChunk{Content: acp.ContentBlock{ + Type: "text", + Text: &acp.TextContent{Text: " I understand you prefer not to make that change. I'll skip the configuration update."}, + }}}, + }); err != nil { + return err + } + default: + return fmt.Errorf("unexpected permission option: %s", permResp.Outcome.Selected.OptionId) + } + return nil +} + +func randomID() string { + var b [12]byte + if _, err := io.ReadFull(rand.Reader, b[:]); err != nil { + // fallback to time-based + return fmt.Sprintf("sess_%d", time.Now().UnixNano()) + } + return "sess_" + hex.EncodeToString(b[:]) +} + +func pause(ctx context.Context, d time.Duration) error { + t := time.NewTimer(d) + defer t.Stop() + select { + case <-ctx.Done(): + return ctx.Err() + case <-t.C: + return nil + } +} + +func main() { + // Wire up stdio: write to stdout, read from stdin + ag := newExampleAgent() + asc := acp.NewAgentSideConnection(ag, os.Stdout, os.Stdin) + ag.SetAgentConnection(asc) + + // Block until the peer disconnects (stdin closes). + <-asc.Done() +} diff --git a/go/example/client/main.go b/go/example/client/main.go new file mode 100644 index 0000000..c129c36 --- /dev/null +++ b/go/example/client/main.go @@ -0,0 +1,164 @@ +package main + +import ( + "bufio" + "fmt" + "os" + "os/exec" + "strings" + + acp "github.com/zed-industries/agent-client-protocol/go" +) + +type exampleClient struct{} + +func (e *exampleClient) RequestPermission(params acp.RequestPermissionRequest) (acp.RequestPermissionResponse, error) { + fmt.Printf("\n🔐 Permission requested: %s\n", params.ToolCall.Title) + fmt.Println("\nOptions:") + for i, opt := range params.Options { + fmt.Printf(" %d. %s (%s)\n", i+1, opt.Name, opt.Kind) + } + reader := bufio.NewReader(os.Stdin) + for { + fmt.Printf("\nChoose an option: ") + line, _ := reader.ReadString('\n') + line = strings.TrimSpace(line) + if line == "" { + continue + } + idx := -1 + fmt.Sscanf(line, "%d", &idx) + idx = idx - 1 + if idx >= 0 && idx < len(params.Options) { + return acp.RequestPermissionResponse{Outcome: acp.RequestPermissionOutcome{Selected: &acp.RequestPermissionOutcomeSelected{OptionId: params.Options[idx].OptionId}}}, nil + } + fmt.Println("Invalid option. Please try again.") + } +} + +func (e *exampleClient) SessionUpdate(params acp.SessionNotification) error { + u := params.Update + switch { + case u.AgentMessageChunk != nil: + c := u.AgentMessageChunk.Content + if c.Type == "text" && c.Text != nil { + fmt.Println(c.Text.Text) + } else { + fmt.Printf("[%s]\n", c.Type) + } + case u.ToolCall != nil: + fmt.Printf("\n🔧 %s (%s)\n", u.ToolCall.Title, u.ToolCall.Status) + case u.ToolCallUpdate != nil: + fmt.Printf("\n🔧 Tool call `%s` updated: %v\n\n", u.ToolCallUpdate.ToolCallId, u.ToolCallUpdate.Status) + case u.Plan != nil || u.AgentThoughtChunk != nil || u.UserMessageChunk != nil: + // Keep output compact for other updates + fmt.Println("[", displayUpdateKind(u), "]") + } + return nil +} + +func displayUpdateKind(u acp.SessionUpdate) string { + switch { + case u.UserMessageChunk != nil: + return "user_message_chunk" + case u.AgentMessageChunk != nil: + return "agent_message_chunk" + case u.AgentThoughtChunk != nil: + return "agent_thought_chunk" + case u.ToolCall != nil: + return "tool_call" + case u.ToolCallUpdate != nil: + return "tool_call_update" + case u.Plan != nil: + return "plan" + default: + return "unknown" + } +} + +func (e *exampleClient) WriteTextFile(params acp.WriteTextFileRequest) error { + fmt.Printf("[Client] Write text file called with: %v\n", params) + return nil +} + +func (e *exampleClient) ReadTextFile(params acp.ReadTextFileRequest) (acp.ReadTextFileResponse, error) { + fmt.Printf("[Client] Read text file called with: %v\n", params) + return acp.ReadTextFileResponse{Content: "Mock file content"}, nil +} + +func main() { + // If args provided, treat them as agent program + args. Otherwise run the Go agent example. + var cmd *exec.Cmd + if len(os.Args) > 1 { + cmd = exec.Command(os.Args[1], os.Args[2:]...) + } else { + // Assumes running from the go/ directory; if not, adjust path accordingly. + cmd = exec.Command("go", "run", "./example/agent") + } + cmd.Stderr = os.Stderr + cmd.Stdout = nil + cmd.Stdin = nil + // Set up pipes for stdio + stdin, _ := cmd.StdinPipe() + stdout, _ := cmd.StdoutPipe() + if err := cmd.Start(); err != nil { + fmt.Fprintf(os.Stderr, "failed to start agent: %v\n", err) + os.Exit(1) + } + + client := &exampleClient{} + conn := acp.NewClientSideConnection(client, stdin, stdout) + + // Initialize + initResp, err := conn.Initialize(acp.InitializeRequest{ + ProtocolVersion: acp.ProtocolVersionNumber, + ClientCapabilities: acp.ClientCapabilities{Fs: acp.FileSystemCapability{ReadTextFile: true, WriteTextFile: true}}, + }) + if err != nil { + fmt.Fprintf(os.Stderr, "initialize error: %v\n", err) + _ = cmd.Process.Kill() + os.Exit(1) + } + fmt.Printf("✅ Connected to agent (protocol v%v)\n", initResp.ProtocolVersion) + + // New session + newSess, err := conn.NewSession(acp.NewSessionRequest{Cwd: mustCwd(), McpServers: []acp.McpServer{}}) + if err != nil { + fmt.Fprintf(os.Stderr, "newSession error: %v\n", err) + _ = cmd.Process.Kill() + os.Exit(1) + } + fmt.Printf("📝 Created session: %s\n", newSess.SessionId) + fmt.Printf("💬 User: Hello, agent!\n\n") + fmt.Print(" ") + + // Send prompt + if _, err := conn.Prompt(acp.PromptRequest{ + SessionId: newSess.SessionId, + Prompt: []acp.ContentBlock{{ + Type: "text", + Text: &acp.TextContent{Text: "Hello, agent!"}, + }}, + }); err != nil { + if re, ok := err.(*acp.RequestError); ok { + fmt.Fprintf(os.Stderr, "prompt error (%d): %s\n", re.Code, re.Message) + if re.Data != nil { + fmt.Fprintf(os.Stderr, "details: %v\n", re.Data) + } + } else { + fmt.Fprintf(os.Stderr, "prompt error: %v\n", err) + } + } else { + fmt.Printf("\n\n✅ Agent completed\n") + } + + _ = cmd.Process.Kill() +} + +func mustCwd() string { + wd, err := os.Getwd() + if err != nil { + return "." + } + return wd +} diff --git a/go/example/gemini/main.go b/go/example/gemini/main.go new file mode 100644 index 0000000..18842db --- /dev/null +++ b/go/example/gemini/main.go @@ -0,0 +1,201 @@ +package main + +import ( + "bufio" + "flag" + "fmt" + "os" + "os/exec" + "strings" + + acp "github.com/zed-industries/agent-client-protocol/go" +) + +// GeminiREPL demonstrates connecting to the Gemini CLI running in ACP mode +// and providing a simple REPL to send prompts and print streamed updates. + +type replClient struct { + autoApprove bool +} + +var _ acp.Client = (*replClient)(nil) + +func (c *replClient) RequestPermission(params acp.RequestPermissionRequest) (acp.RequestPermissionResponse, error) { + if c.autoApprove { + // Prefer an allow option if present; otherwise choose the first option. + for _, o := range params.Options { + if o.Kind == acp.PermissionOptionKindAllowOnce || o.Kind == acp.PermissionOptionKindAllowAlways { + return acp.RequestPermissionResponse{Outcome: acp.RequestPermissionOutcome{Selected: &acp.RequestPermissionOutcomeSelected{OptionId: o.OptionId}}}, nil + } + } + if len(params.Options) > 0 { + return acp.RequestPermissionResponse{Outcome: acp.RequestPermissionOutcome{Selected: &acp.RequestPermissionOutcomeSelected{OptionId: params.Options[0].OptionId}}}, nil + } + return acp.RequestPermissionResponse{Outcome: acp.RequestPermissionOutcome{Cancelled: &acp.RequestPermissionOutcomeCancelled{}}}, nil + } + + fmt.Printf("\n🔐 Permission requested: %s\n", params.ToolCall.Title) + fmt.Println("\nOptions:") + for i, opt := range params.Options { + fmt.Printf(" %d. %s (%s)\n", i+1, opt.Name, opt.Kind) + } + reader := bufio.NewReader(os.Stdin) + for { + fmt.Printf("\nChoose an option: ") + line, _ := reader.ReadString('\n') + line = strings.TrimSpace(line) + if line == "" { + continue + } + idx := -1 + fmt.Sscanf(line, "%d", &idx) + idx = idx - 1 + if idx >= 0 && idx < len(params.Options) { + return acp.RequestPermissionResponse{Outcome: acp.RequestPermissionOutcome{Selected: &acp.RequestPermissionOutcomeSelected{OptionId: params.Options[idx].OptionId}}}, nil + } + fmt.Println("Invalid option. Please try again.") + } +} + +func (c *replClient) SessionUpdate(params acp.SessionNotification) error { + u := params.Update + switch { + case u.AgentMessageChunk != nil: + content := u.AgentMessageChunk.Content + if content.Type == "text" && content.Text != nil { + fmt.Println(content.Text.Text) + } else { + fmt.Printf("[%s]\n", content.Type) + } + case u.ToolCall != nil: + fmt.Printf("\n🔧 %s (%s)\n", u.ToolCall.Title, u.ToolCall.Status) + case u.ToolCallUpdate != nil: + fmt.Printf("\n🔧 Tool call `%s` updated: %v\n\n", u.ToolCallUpdate.ToolCallId, u.ToolCallUpdate.Status) + case u.Plan != nil: + fmt.Println("[plan update]") + case u.AgentThoughtChunk != nil: + fmt.Println("[agent_thought_chunk]") + case u.UserMessageChunk != nil: + fmt.Println("[user_message_chunk]") + } + return nil +} + +func (c *replClient) WriteTextFile(params acp.WriteTextFileRequest) error { + // For demo purposes, just log the request and allow it. + fmt.Printf("[Client] WriteTextFile: %v\n", params) + return nil +} + +func (c *replClient) ReadTextFile(params acp.ReadTextFileRequest) (acp.ReadTextFileResponse, error) { + fmt.Printf("[Client] ReadTextFile: %v\n", params) + return acp.ReadTextFileResponse{Content: "Mock file content"}, nil +} + +func main() { + binary := flag.String("gemini", "gemini", "Path to the Gemini CLI binary") + model := flag.String("model", "", "Model to pass to Gemini (optional)") + sandbox := flag.Bool("sandbox", false, "Run Gemini in sandbox mode") + yolo := flag.Bool("yolo", false, "Auto-approve permission prompts") + debug := flag.Bool("debug", false, "Pass --debug to Gemini") + flag.Parse() + + args := []string{"--experimental-acp"} + if *model != "" { + args = append(args, "--model", *model) + } + if *sandbox { + args = append(args, "--sandbox") + } + if *debug { + args = append(args, "--debug") + } + + cmd := exec.Command(*binary, args...) + cmd.Stderr = os.Stderr + stdin, err := cmd.StdinPipe() + if err != nil { + fmt.Fprintf(os.Stderr, "stdin pipe error: %v\n", err) + os.Exit(1) + } + stdout, err := cmd.StdoutPipe() + if err != nil { + fmt.Fprintf(os.Stderr, "stdout pipe error: %v\n", err) + os.Exit(1) + } + + if err := cmd.Start(); err != nil { + fmt.Fprintf(os.Stderr, "failed to start Gemini: %v\n", err) + os.Exit(1) + } + + client := &replClient{autoApprove: *yolo} + conn := acp.NewClientSideConnection(client, stdin, stdout) + + // Initialize + initResp, err := conn.Initialize(acp.InitializeRequest{ + ProtocolVersion: acp.ProtocolVersionNumber, + ClientCapabilities: acp.ClientCapabilities{Fs: acp.FileSystemCapability{ReadTextFile: true, WriteTextFile: true}}, + }) + if err != nil { + fmt.Fprintf(os.Stderr, "initialize error: %v\n", err) + _ = cmd.Process.Kill() + os.Exit(1) + } + fmt.Printf("✅ Connected to Gemini (protocol v%v)\n", initResp.ProtocolVersion) + + // New session + newSess, err := conn.NewSession(acp.NewSessionRequest{Cwd: mustCwd(), McpServers: []acp.McpServer{}}) + if err != nil { + fmt.Fprintf(os.Stderr, "newSession error: %v\n", err) + _ = cmd.Process.Kill() + os.Exit(1) + } + fmt.Printf("📝 Created session: %s\n", newSess.SessionId) + + fmt.Println("Type a message and press Enter to send. Commands: :cancel, :exit") + scanner := bufio.NewScanner(os.Stdin) + for { + fmt.Print("> ") + if !scanner.Scan() { + break + } + line := strings.TrimSpace(scanner.Text()) + if line == "" { + continue + } + switch line { + case ":exit", ":quit": + _ = cmd.Process.Kill() + return + case ":cancel": + _ = conn.Cancel(acp.CancelNotification(newSess)) + continue + } + // Send prompt and wait for completion while streaming updates are printed via SessionUpdate + if _, err := conn.Prompt(acp.PromptRequest{ + SessionId: newSess.SessionId, + Prompt: []acp.ContentBlock{{Type: "text", Text: &acp.TextContent{Text: line}}}, + }); err != nil { + // If it's a JSON-RPC RequestError, surface more detail for troubleshooting + if re, ok := err.(*acp.RequestError); ok { + fmt.Fprintf(os.Stderr, "prompt error (%d): %s\n", re.Code, re.Message) + if re.Data != nil { + fmt.Fprintf(os.Stderr, "details: %v\n", re.Data) + } + } else { + fmt.Fprintf(os.Stderr, "prompt error: %v\n", err) + } + } + } + + _ = cmd.Process.Kill() +} + +func mustCwd() string { + wd, err := os.Getwd() + if err != nil { + return "." + } + return wd +} diff --git a/go/go.mod b/go/go.mod new file mode 100644 index 0000000..0daf94a --- /dev/null +++ b/go/go.mod @@ -0,0 +1,3 @@ +module github.com/zed-industries/agent-client-protocol/go + +go 1.21 diff --git a/go/types.go b/go/types.go new file mode 100644 index 0000000..58a5b2c --- /dev/null +++ b/go/types.go @@ -0,0 +1,814 @@ +// Code generated by acp-go-generator; DO NOT EDIT. + +package acp + +import "encoding/json" + +// Capabilities supported by the agent. Advertised during initialization to inform the client about available features and content types. See protocol docs: [Agent Capabilities](https://agentclientprotocol.com/protocol/initialization#agent-capabilities) +type AgentCapabilities struct { + // Whether the agent supports 'session/load'. + LoadSession bool `json:"loadSession,omitempty"` + // Prompt capabilities supported by the agent. + PromptCapabilities PromptCapabilities `json:"promptCapabilities,omitempty"` +} + +// Optional annotations for the client. The client can use annotations to inform how objects are used or displayed +type Annotations struct { + Audience []Role `json:"audience,omitempty"` + LastModified string `json:"lastModified,omitempty"` + Priority float64 `json:"priority,omitempty"` +} + +// Audio provided to or from an LLM. +type AudioContent struct { + Annotations any `json:"annotations,omitempty"` + Data string `json:"data"` + MimeType string `json:"mimeType"` +} + +// Describes an available authentication method. +type AuthMethod struct { + // Optional description providing more details about this authentication method. + Description string `json:"description,omitempty"` + // Unique identifier for this authentication method. + Id AuthMethodId `json:"id"` + // Human-readable name of the authentication method. + Name string `json:"name"` +} + +// Unique identifier for an authentication method. +type AuthMethodId string + +// Request parameters for the authenticate method. Specifies which authentication method to use. +type AuthenticateRequest struct { + // The ID of the authentication method to use. Must be one of the methods advertised in the initialize response. + MethodId AuthMethodId `json:"methodId"` +} + +// Binary resource contents. +type BlobResourceContents struct { + Blob string `json:"blob"` + MimeType string `json:"mimeType,omitempty"` + Uri string `json:"uri"` +} + +// Notification to cancel ongoing operations for a session. See protocol docs: [Cancellation](https://agentclientprotocol.com/protocol/prompt-turn#cancellation) +type CancelNotification struct { + // The ID of the session to cancel operations for. + SessionId SessionId `json:"sessionId"` +} + +// Capabilities supported by the client. Advertised during initialization to inform the agent about available features and methods. See protocol docs: [Client Capabilities](https://agentclientprotocol.com/protocol/initialization#client-capabilities) +type ClientCapabilities struct { + // File system capabilities supported by the client. Determines which file operations the agent can request. + Fs FileSystemCapability `json:"fs,omitempty"` + // **UNSTABLE** This capability is not part of the spec yet, and may be removed or changed at any point. + Terminal bool `json:"terminal,omitempty"` +} + +// Content blocks represent displayable information in the Agent Client Protocol. They provide a structured way to handle various types of user-facing content—whether it's text from language models, images for analysis, or embedded resources for context. Content blocks appear in: - User prompts sent via 'session/prompt' - Language model output streamed through 'session/update' notifications - Progress updates and results from tool calls This structure is compatible with the Model Context Protocol (MCP), enabling agents to seamlessly forward content from MCP tool outputs without transformation. See protocol docs: [Content](https://agentclientprotocol.com/protocol/content) +type ResourceLinkContent struct { + Annotations any `json:"annotations,omitempty"` + Description *string `json:"description,omitempty"` + MimeType *string `json:"mimeType,omitempty"` + Name string `json:"name"` + Size *int64 `json:"size,omitempty"` + Title *string `json:"title,omitempty"` + Uri string `json:"uri"` +} + +type ContentBlock struct { + Type string `json:"type"` + Text *TextContent `json:"-"` + Image *ImageContent `json:"-"` + Audio *AudioContent `json:"-"` + ResourceLink *ResourceLinkContent `json:"-"` + Resource *EmbeddedResource `json:"-"` +} + +func (c *ContentBlock) UnmarshalJSON(b []byte) error { + var probe struct { + Type string `json:"type"` + } + if err := json.Unmarshal(b, &probe); err != nil { + return err + } + c.Type = probe.Type + switch probe.Type { + case "text": + var v TextContent + if err := json.Unmarshal(b, &v); err != nil { + return err + } + c.Text = &v + case "image": + var v ImageContent + if err := json.Unmarshal(b, &v); err != nil { + return err + } + c.Image = &v + case "audio": + var v AudioContent + if err := json.Unmarshal(b, &v); err != nil { + return err + } + c.Audio = &v + case "resource_link": + var v ResourceLinkContent + if err := json.Unmarshal(b, &v); err != nil { + return err + } + c.ResourceLink = &v + case "resource": + var v EmbeddedResource + if err := json.Unmarshal(b, &v); err != nil { + return err + } + c.Resource = &v + } + return nil +} +func (c ContentBlock) MarshalJSON() ([]byte, error) { + switch c.Type { + case "text": + if c.Text != nil { + return json.Marshal(map[string]any{ + "text": c.Text.Text, + "type": "text", + }) + } + case "image": + if c.Image != nil { + return json.Marshal(map[string]any{ + "data": c.Image.Data, + "mimeType": c.Image.MimeType, + "type": "image", + "uri": c.Image.Uri, + }) + } + case "audio": + if c.Audio != nil { + return json.Marshal(map[string]any{ + "data": c.Audio.Data, + "mimeType": c.Audio.MimeType, + "type": "audio", + }) + } + case "resource_link": + if c.ResourceLink != nil { + return json.Marshal(map[string]any{ + "description": c.ResourceLink.Description, + "mimeType": c.ResourceLink.MimeType, + "name": c.ResourceLink.Name, + "size": c.ResourceLink.Size, + "title": c.ResourceLink.Title, + "type": "resource_link", + "uri": c.ResourceLink.Uri, + }) + } + case "resource": + if c.Resource != nil { + return json.Marshal(map[string]any{ + "resource": c.Resource.Resource, + "type": "resource", + }) + } + } + return []byte{}, nil +} + +// The contents of a resource, embedded into a prompt or tool call result. +type EmbeddedResource struct { + Annotations any `json:"annotations,omitempty"` + Resource EmbeddedResourceResource `json:"resource"` +} + +// Resource content that can be embedded in a message. +type EmbeddedResourceResource struct { + TextResourceContents *TextResourceContents `json:"-"` + BlobResourceContents *BlobResourceContents `json:"-"` +} + +func (e *EmbeddedResourceResource) UnmarshalJSON(b []byte) error { + var m map[string]json.RawMessage + if err := json.Unmarshal(b, &m); err != nil { + return err + } + if _, ok := m["text"]; ok { + var v TextResourceContents + if err := json.Unmarshal(b, &v); err != nil { + return err + } + e.TextResourceContents = &v + return nil + } + if _, ok2 := m["blob"]; ok2 { + var v BlobResourceContents + if err := json.Unmarshal(b, &v); err != nil { + return err + } + e.BlobResourceContents = &v + return nil + } + return nil +} + +// An environment variable to set when launching an MCP server. +type EnvVariable struct { + // The name of the environment variable. + Name string `json:"name"` + // The value to set for the environment variable. + Value string `json:"value"` +} + +// File system capabilities that a client may support. See protocol docs: [FileSystem](https://agentclientprotocol.com/protocol/initialization#filesystem) +type FileSystemCapability struct { + // Whether the Client supports 'fs/read_text_file' requests. + ReadTextFile bool `json:"readTextFile,omitempty"` + // Whether the Client supports 'fs/write_text_file' requests. + WriteTextFile bool `json:"writeTextFile,omitempty"` +} + +// An image provided to or from an LLM. +type ImageContent struct { + Annotations any `json:"annotations,omitempty"` + Data string `json:"data"` + MimeType string `json:"mimeType"` + Uri string `json:"uri,omitempty"` +} + +// Request parameters for the initialize method. Sent by the client to establish connection and negotiate capabilities. See protocol docs: [Initialization](https://agentclientprotocol.com/protocol/initialization) +type InitializeRequest struct { + // Capabilities supported by the client. + ClientCapabilities ClientCapabilities `json:"clientCapabilities,omitempty"` + // The latest protocol version supported by the client. + ProtocolVersion ProtocolVersion `json:"protocolVersion"` +} + +// Response from the initialize method. Contains the negotiated protocol version and agent capabilities. See protocol docs: [Initialization](https://agentclientprotocol.com/protocol/initialization) +type InitializeResponse struct { + // Capabilities supported by the agent. + AgentCapabilities AgentCapabilities `json:"agentCapabilities,omitempty"` + // Authentication methods supported by the agent. + AuthMethods []AuthMethod `json:"authMethods,omitempty"` + // The protocol version the client specified if supported by the agent, or the latest protocol version supported by the agent. The client should disconnect, if it doesn't support this version. + ProtocolVersion ProtocolVersion `json:"protocolVersion"` +} + +// Request parameters for loading an existing session. Only available if the agent supports the 'loadSession' capability. See protocol docs: [Loading Sessions](https://agentclientprotocol.com/protocol/session-setup#loading-sessions) +type LoadSessionRequest struct { + // The working directory for this session. + Cwd string `json:"cwd"` + // List of MCP servers to connect to for this session. + McpServers []McpServer `json:"mcpServers"` + // The ID of the session to load. + SessionId SessionId `json:"sessionId"` +} + +// Configuration for connecting to an MCP (Model Context Protocol) server. MCP servers provide tools and context that the agent can use when processing prompts. See protocol docs: [MCP Servers](https://agentclientprotocol.com/protocol/session-setup#mcp-servers) +type McpServer struct { + // Command-line arguments to pass to the MCP server. + Args []string `json:"args"` + // Path to the MCP server executable. + Command string `json:"command"` + // Environment variables to set when launching the MCP server. + Env []EnvVariable `json:"env"` + // Human-readable name identifying this MCP server. + Name string `json:"name"` +} + +// Request parameters for creating a new session. See protocol docs: [Creating a Session](https://agentclientprotocol.com/protocol/session-setup#creating-a-session) +type NewSessionRequest struct { + // The working directory for this session. Must be an absolute path. + Cwd string `json:"cwd"` + // List of MCP (Model Context Protocol) servers the agent should connect to. + McpServers []McpServer `json:"mcpServers"` +} + +// Response from creating a new session. See protocol docs: [Creating a Session](https://agentclientprotocol.com/protocol/session-setup#creating-a-session) +type NewSessionResponse struct { + // Unique identifier for the created session. Used in all subsequent requests for this conversation. + SessionId SessionId `json:"sessionId"` +} + +// An option presented to the user when requesting permission. +type PermissionOption struct { + // Hint about the nature of this permission option. + Kind PermissionOptionKind `json:"kind"` + // Human-readable label to display to the user. + Name string `json:"name"` + // Unique identifier for this permission option. + OptionId PermissionOptionId `json:"optionId"` +} + +// Unique identifier for a permission option. +type PermissionOptionId string + +// The type of permission option being presented to the user. Helps clients choose appropriate icons and UI treatment. +type PermissionOptionKind string + +const ( + PermissionOptionKindAllowOnce PermissionOptionKind = "allow_once" + PermissionOptionKindAllowAlways PermissionOptionKind = "allow_always" + PermissionOptionKindRejectOnce PermissionOptionKind = "reject_once" + PermissionOptionKindRejectAlways PermissionOptionKind = "reject_always" +) + +// An execution plan for accomplishing complex tasks. Plans consist of multiple entries representing individual tasks or goals. Agents report plans to clients to provide visibility into their execution strategy. Plans can evolve during execution as the agent discovers new requirements or completes tasks. See protocol docs: [Agent Plan](https://agentclientprotocol.com/protocol/agent-plan) +type Plan struct { + // The list of tasks to be accomplished. When updating a plan, the agent must send a complete list of all entries with their current status. The client replaces the entire plan with each update. + Entries []PlanEntry `json:"entries"` +} + +// A single entry in the execution plan. Represents a task or goal that the assistant intends to accomplish as part of fulfilling the user's request. See protocol docs: [Plan Entries](https://agentclientprotocol.com/protocol/agent-plan#plan-entries) +type PlanEntry struct { + // Human-readable description of what this task aims to accomplish. + Content string `json:"content"` + // The relative importance of this task. Used to indicate which tasks are most critical to the overall goal. + Priority PlanEntryPriority `json:"priority"` + // Current execution status of this task. + Status PlanEntryStatus `json:"status"` +} + +// Priority levels for plan entries. Used to indicate the relative importance or urgency of different tasks in the execution plan. See protocol docs: [Plan Entries](https://agentclientprotocol.com/protocol/agent-plan#plan-entries) +type PlanEntryPriority string + +const ( + PlanEntryPriorityHigh PlanEntryPriority = "high" + PlanEntryPriorityMedium PlanEntryPriority = "medium" + PlanEntryPriorityLow PlanEntryPriority = "low" +) + +// Status of a plan entry in the execution flow. Tracks the lifecycle of each task from planning through completion. See protocol docs: [Plan Entries](https://agentclientprotocol.com/protocol/agent-plan#plan-entries) +type PlanEntryStatus string + +const ( + PlanEntryStatusPending PlanEntryStatus = "pending" + PlanEntryStatusInProgress PlanEntryStatus = "in_progress" + PlanEntryStatusCompleted PlanEntryStatus = "completed" +) + +// Prompt capabilities supported by the agent in 'session/prompt' requests. Baseline agent functionality requires support for ['ContentBlock::Text'] and ['ContentBlock::ResourceLink'] in prompt requests. Other variants must be explicitly opted in to. Capabilities for different types of content in prompt requests. Indicates which content types beyond the baseline (text and resource links) the agent can process. See protocol docs: [Prompt Capabilities](https://agentclientprotocol.com/protocol/initialization#prompt-capabilities) +type PromptCapabilities struct { + // Agent supports ['ContentBlock::Audio']. + Audio bool `json:"audio,omitempty"` + // Agent supports embedded context in 'session/prompt' requests. When enabled, the Client is allowed to include ['ContentBlock::Resource'] in prompt requests for pieces of context that are referenced in the message. + EmbeddedContext bool `json:"embeddedContext,omitempty"` + // Agent supports ['ContentBlock::Image']. + Image bool `json:"image,omitempty"` +} + +// Request parameters for sending a user prompt to the agent. Contains the user's message and any additional context. See protocol docs: [User Message](https://agentclientprotocol.com/protocol/prompt-turn#1-user-message) +type PromptRequest struct { + // The blocks of content that compose the user's message. As a baseline, the Agent MUST support ['ContentBlock::Text'] and ['ContentBlock::ResourceLink'], while other variants are optionally enabled via ['PromptCapabilities']. The Client MUST adapt its interface according to ['PromptCapabilities']. The client MAY include referenced pieces of context as either ['ContentBlock::Resource'] or ['ContentBlock::ResourceLink']. When available, ['ContentBlock::Resource'] is preferred as it avoids extra round-trips and allows the message to include pieces of context from sources the agent may not have access to. + Prompt []ContentBlock `json:"prompt"` + // The ID of the session to send this user message to + SessionId SessionId `json:"sessionId"` +} + +// Response from processing a user prompt. See protocol docs: [Check for Completion](https://agentclientprotocol.com/protocol/prompt-turn#4-check-for-completion) +type PromptResponse struct { + // Indicates why the agent stopped processing the turn. + StopReason StopReason `json:"stopReason"` +} + +// Protocol version identifier. This version is only bumped for breaking changes. Non-breaking changes should be introduced via capabilities. +type ProtocolVersion int + +// Request to read content from a text file. Only available if the client supports the 'fs.readTextFile' capability. +type ReadTextFileRequest struct { + // Optional maximum number of lines to read. + Limit int `json:"limit,omitempty"` + // Optional line number to start reading from (1-based). + Line int `json:"line,omitempty"` + // Absolute path to the file to read. + Path string `json:"path"` + // The session ID for this request. + SessionId SessionId `json:"sessionId"` +} + +// Response containing the contents of a text file. +type ReadTextFileResponse struct { + Content string `json:"content"` +} + +// The outcome of a permission request. +type RequestPermissionOutcomeCancelled struct{} +type RequestPermissionOutcomeSelected struct { + OptionId PermissionOptionId `json:"optionId"` +} + +type RequestPermissionOutcome struct { + Cancelled *RequestPermissionOutcomeCancelled `json:"-"` + Selected *RequestPermissionOutcomeSelected `json:"-"` +} + +func (o *RequestPermissionOutcome) UnmarshalJSON(b []byte) error { + var m map[string]json.RawMessage + if err := json.Unmarshal(b, &m); err != nil { + return err + } + var outcome string + if v, ok := m["outcome"]; ok { + json.Unmarshal(v, &outcome) + } + switch outcome { + case "cancelled": + o.Cancelled = &RequestPermissionOutcomeCancelled{} + return nil + case "selected": + var v2 RequestPermissionOutcomeSelected + if err := json.Unmarshal(b, &v2); err != nil { + return err + } + o.Selected = &v2 + return nil + } + return nil +} +func (o RequestPermissionOutcome) MarshalJSON() ([]byte, error) { + if o.Cancelled != nil { + return json.Marshal(map[string]any{"outcome": "cancelled"}) + } + if o.Selected != nil { + return json.Marshal(map[string]any{ + "optionId": o.Selected.OptionId, + "outcome": "selected", + }) + } + return []byte{}, nil +} + +// Request for user permission to execute a tool call. Sent when the agent needs authorization before performing a sensitive operation. See protocol docs: [Requesting Permission](https://agentclientprotocol.com/protocol/tool-calls#requesting-permission) +type RequestPermissionRequest struct { + // Available permission options for the user to choose from. + Options []PermissionOption `json:"options"` + // The session ID for this request. + SessionId SessionId `json:"sessionId"` + // Details about the tool call requiring permission. + ToolCall ToolCallUpdate `json:"toolCall"` +} + +// Response to a permission request. +type RequestPermissionResponse struct { + // The user's decision on the permission request. + Outcome RequestPermissionOutcome `json:"outcome"` +} + +// A resource that the server is capable of reading, included in a prompt or tool call result. +type ResourceLink struct { + Annotations any `json:"annotations,omitempty"` + Description string `json:"description,omitempty"` + MimeType string `json:"mimeType,omitempty"` + Name string `json:"name"` + Size int `json:"size,omitempty"` + Title string `json:"title,omitempty"` + Uri string `json:"uri"` +} + +// The sender or recipient of messages and data in a conversation. +type Role string + +const ( + RoleAssistant Role = "assistant" + RoleUser Role = "user" +) + +// A unique identifier for a conversation session between a client and agent. Sessions maintain their own context, conversation history, and state, allowing multiple independent interactions with the same agent. # Example ”' use agent_client_protocol::SessionId; use std::sync::Arc; let session_id = SessionId(Arc::from("sess_abc123def456")); ”' See protocol docs: [Session ID](https://agentclientprotocol.com/protocol/session-setup#session-id) +type SessionId string + +// Notification containing a session update from the agent. Used to stream real-time progress and results during prompt processing. See protocol docs: [Agent Reports Output](https://agentclientprotocol.com/protocol/prompt-turn#3-agent-reports-output) +type SessionNotification struct { + // The ID of the session this update pertains to. + SessionId SessionId `json:"sessionId"` + // The actual update content. + Update SessionUpdate `json:"update"` +} + +// Different types of updates that can be sent during session processing. These updates provide real-time feedback about the agent's progress. See protocol docs: [Agent Reports Output](https://agentclientprotocol.com/protocol/prompt-turn#3-agent-reports-output) +type SessionUpdateUserMessageChunk struct { + Content ContentBlock `json:"content"` +} +type SessionUpdateAgentMessageChunk struct { + Content ContentBlock `json:"content"` +} +type SessionUpdateAgentThoughtChunk struct { + Content ContentBlock `json:"content"` +} +type SessionUpdateToolCall struct { + Content []ToolCallContent `json:"content,omitempty"` + Kind ToolKind `json:"kind,omitempty"` + Locations []ToolCallLocation `json:"locations,omitempty"` + RawInput any `json:"rawInput,omitempty"` + RawOutput any `json:"rawOutput,omitempty"` + Status ToolCallStatus `json:"status,omitempty"` + Title string `json:"title"` + ToolCallId ToolCallId `json:"toolCallId"` +} +type SessionUpdateToolCallUpdate struct { + Content []ToolCallContent `json:"content,omitempty"` + Kind any `json:"kind,omitempty"` + Locations []ToolCallLocation `json:"locations,omitempty"` + RawInput any `json:"rawInput,omitempty"` + RawOutput any `json:"rawOutput,omitempty"` + Status any `json:"status,omitempty"` + Title *string `json:"title,omitempty"` + ToolCallId ToolCallId `json:"toolCallId"` +} +type SessionUpdatePlan struct { + Entries []PlanEntry `json:"entries"` +} + +type SessionUpdate struct { + UserMessageChunk *SessionUpdateUserMessageChunk `json:"-"` + AgentMessageChunk *SessionUpdateAgentMessageChunk `json:"-"` + AgentThoughtChunk *SessionUpdateAgentThoughtChunk `json:"-"` + ToolCall *SessionUpdateToolCall `json:"-"` + ToolCallUpdate *SessionUpdateToolCallUpdate `json:"-"` + Plan *SessionUpdatePlan `json:"-"` +} + +func (s *SessionUpdate) UnmarshalJSON(b []byte) error { + var m map[string]json.RawMessage + if err := json.Unmarshal(b, &m); err != nil { + return err + } + var kind string + if v, ok := m["sessionUpdate"]; ok { + json.Unmarshal(v, &kind) + } + switch kind { + case "user_message_chunk": + var v SessionUpdateUserMessageChunk + if err := json.Unmarshal(b, &v); err != nil { + return err + } + s.UserMessageChunk = &v + return nil + case "agent_message_chunk": + var v SessionUpdateAgentMessageChunk + if err := json.Unmarshal(b, &v); err != nil { + return err + } + s.AgentMessageChunk = &v + return nil + case "agent_thought_chunk": + var v SessionUpdateAgentThoughtChunk + if err := json.Unmarshal(b, &v); err != nil { + return err + } + s.AgentThoughtChunk = &v + return nil + case "tool_call": + var v SessionUpdateToolCall + if err := json.Unmarshal(b, &v); err != nil { + return err + } + s.ToolCall = &v + return nil + case "tool_call_update": + var v SessionUpdateToolCallUpdate + if err := json.Unmarshal(b, &v); err != nil { + return err + } + s.ToolCallUpdate = &v + return nil + case "plan": + var v SessionUpdatePlan + if err := json.Unmarshal(b, &v); err != nil { + return err + } + s.Plan = &v + return nil + } + return nil +} +func (s SessionUpdate) MarshalJSON() ([]byte, error) { + if s.UserMessageChunk != nil { + return json.Marshal(map[string]any{ + "content": s.UserMessageChunk.Content, + "sessionUpdate": "user_message_chunk", + }) + } + if s.AgentMessageChunk != nil { + return json.Marshal(map[string]any{ + "content": s.AgentMessageChunk.Content, + "sessionUpdate": "agent_message_chunk", + }) + } + if s.AgentThoughtChunk != nil { + return json.Marshal(map[string]any{ + "content": s.AgentThoughtChunk.Content, + "sessionUpdate": "agent_thought_chunk", + }) + } + if s.ToolCall != nil { + return json.Marshal(map[string]any{ + "content": s.ToolCall.Content, + "kind": s.ToolCall.Kind, + "locations": s.ToolCall.Locations, + "rawInput": s.ToolCall.RawInput, + "rawOutput": s.ToolCall.RawOutput, + "sessionUpdate": "tool_call", + "status": s.ToolCall.Status, + "title": s.ToolCall.Title, + "toolCallId": s.ToolCall.ToolCallId, + }) + } + if s.ToolCallUpdate != nil { + return json.Marshal(map[string]any{ + "content": s.ToolCallUpdate.Content, + "kind": s.ToolCallUpdate.Kind, + "locations": s.ToolCallUpdate.Locations, + "rawInput": s.ToolCallUpdate.RawInput, + "rawOutput": s.ToolCallUpdate.RawOutput, + "sessionUpdate": "tool_call_update", + "status": s.ToolCallUpdate.Status, + "title": s.ToolCallUpdate.Title, + "toolCallId": s.ToolCallUpdate.ToolCallId, + }) + } + if s.Plan != nil { + return json.Marshal(map[string]any{ + "entries": s.Plan.Entries, + "sessionUpdate": "plan", + }) + } + return []byte{}, nil +} + +// Reasons why an agent stops processing a prompt turn. See protocol docs: [Stop Reasons](https://agentclientprotocol.com/protocol/prompt-turn#stop-reasons) +type StopReason string + +const ( + StopReasonEndTurn StopReason = "end_turn" + StopReasonMaxTokens StopReason = "max_tokens" + StopReasonMaxTurnRequests StopReason = "max_turn_requests" + StopReasonRefusal StopReason = "refusal" + StopReasonCancelled StopReason = "cancelled" +) + +// Text provided to or from an LLM. +type TextContent struct { + Annotations any `json:"annotations,omitempty"` + Text string `json:"text"` +} + +// Text-based resource contents. +type TextResourceContents struct { + MimeType string `json:"mimeType,omitempty"` + Text string `json:"text"` + Uri string `json:"uri"` +} + +// Represents a tool call that the language model has requested. Tool calls are actions that the agent executes on behalf of the language model, such as reading files, executing code, or fetching data from external sources. See protocol docs: [Tool Calls](https://agentclientprotocol.com/protocol/tool-calls) +type ToolCall struct { + // Content produced by the tool call. + Content []ToolCallContent `json:"content,omitempty"` + // The category of tool being invoked. Helps clients choose appropriate icons and UI treatment. + Kind ToolKind `json:"kind,omitempty"` + // File locations affected by this tool call. Enables "follow-along" features in clients. + Locations []ToolCallLocation `json:"locations,omitempty"` + // Raw input parameters sent to the tool. + RawInput any `json:"rawInput,omitempty"` + // Raw output returned by the tool. + RawOutput any `json:"rawOutput,omitempty"` + // Current execution status of the tool call. + Status ToolCallStatus `json:"status,omitempty"` + // Human-readable title describing what the tool is doing. + Title string `json:"title"` + // Unique identifier for this tool call within the session. + ToolCallId ToolCallId `json:"toolCallId"` +} + +// Content produced by a tool call. Tool calls can produce different types of content including standard content blocks (text, images) or file diffs. See protocol docs: [Content](https://agentclientprotocol.com/protocol/tool-calls#content) +type DiffContent struct { + NewText string `json:"newText"` + OldText *string `json:"oldText,omitempty"` + Path string `json:"path"` +} +type TerminalRef struct { + TerminalId string `json:"terminalId"` +} + +type ToolCallContent struct { + Type string `json:"type"` + Content *ContentBlock `json:"-"` + Diff *DiffContent `json:"-"` + Terminal *TerminalRef `json:"-"` +} + +func (t *ToolCallContent) UnmarshalJSON(b []byte) error { + var probe struct { + Type string `json:"type"` + } + if err := json.Unmarshal(b, &probe); err != nil { + return err + } + t.Type = probe.Type + switch probe.Type { + case "content": + var v struct { + Type string `json:"type"` + Content ContentBlock `json:"content"` + } + if err := json.Unmarshal(b, &v); err != nil { + return err + } + t.Content = &v.Content + case "diff": + var v DiffContent + if err := json.Unmarshal(b, &v); err != nil { + return err + } + t.Diff = &v + case "terminal": + var v TerminalRef + if err := json.Unmarshal(b, &v); err != nil { + return err + } + t.Terminal = &v + } + return nil +} + +// Unique identifier for a tool call within a session. +type ToolCallId string + +// A file location being accessed or modified by a tool. Enables clients to implement "follow-along" features that track which files the agent is working with in real-time. See protocol docs: [Following the Agent](https://agentclientprotocol.com/protocol/tool-calls#following-the-agent) +type ToolCallLocation struct { + // Optional line number within the file. + Line int `json:"line,omitempty"` + // The file path being accessed or modified. + Path string `json:"path"` +} + +// Execution status of a tool call. Tool calls progress through different statuses during their lifecycle. See protocol docs: [Status](https://agentclientprotocol.com/protocol/tool-calls#status) +type ToolCallStatus string + +const ( + ToolCallStatusPending ToolCallStatus = "pending" + ToolCallStatusInProgress ToolCallStatus = "in_progress" + ToolCallStatusCompleted ToolCallStatus = "completed" + ToolCallStatusFailed ToolCallStatus = "failed" +) + +// An update to an existing tool call. Used to report progress and results as tools execute. All fields except the tool call ID are optional - only changed fields need to be included. See protocol docs: [Updating](https://agentclientprotocol.com/protocol/tool-calls#updating) +type ToolCallUpdate struct { + // Replace the content collection. + Content []ToolCallContent `json:"content,omitempty"` + // Update the tool kind. + Kind any `json:"kind,omitempty"` + // Replace the locations collection. + Locations []ToolCallLocation `json:"locations,omitempty"` + // Update the raw input. + RawInput any `json:"rawInput,omitempty"` + // Update the raw output. + RawOutput any `json:"rawOutput,omitempty"` + // Update the execution status. + Status any `json:"status,omitempty"` + // Update the human-readable title. + Title string `json:"title,omitempty"` + // The ID of the tool call being updated. + ToolCallId ToolCallId `json:"toolCallId"` +} + +// Categories of tools that can be invoked. Tool kinds help clients choose appropriate icons and optimize how they display tool execution progress. See protocol docs: [Creating](https://agentclientprotocol.com/protocol/tool-calls#creating) +type ToolKind string + +const ( + ToolKindRead ToolKind = "read" + ToolKindEdit ToolKind = "edit" + ToolKindDelete ToolKind = "delete" + ToolKindMove ToolKind = "move" + ToolKindSearch ToolKind = "search" + ToolKindExecute ToolKind = "execute" + ToolKindThink ToolKind = "think" + ToolKindFetch ToolKind = "fetch" + ToolKindOther ToolKind = "other" +) + +// Request to write content to a text file. Only available if the client supports the 'fs.writeTextFile' capability. +type WriteTextFileRequest struct { + // The text content to write to the file. + Content string `json:"content"` + // Absolute path to the file to write. + Path string `json:"path"` + // The session ID for this request. + SessionId SessionId `json:"sessionId"` +} + +type Agent interface { + Authenticate(params AuthenticateRequest) error + Initialize(params InitializeRequest) (InitializeResponse, error) + Cancel(params CancelNotification) error + LoadSession(params LoadSessionRequest) error + NewSession(params NewSessionRequest) (NewSessionResponse, error) + Prompt(params PromptRequest) (PromptResponse, error) +} +type Client interface { + ReadTextFile(params ReadTextFileRequest) (ReadTextFileResponse, error) + WriteTextFile(params WriteTextFileRequest) error + RequestPermission(params RequestPermissionRequest) (RequestPermissionResponse, error) + SessionUpdate(params SessionNotification) error +} diff --git a/package.json b/package.json index 2bd88b4..2e516ca 100644 --- a/package.json +++ b/package.json @@ -36,13 +36,15 @@ "test:ts:watch": "vitest", "generate:json-schema": "cd rust && cargo run --bin generate --features unstable", "generate:ts-schema": "node typescript/generate.js", - "generate": "npm run generate:json-schema && npm run generate:ts-schema && npm run format", + "generate:go": "cd go/cmd/generate && go run . && cd ../.. && go fmt ./...", + "generate": "npm run generate:json-schema && npm run generate:ts-schema && npm run generate:go && npm run format", "build": "npm run generate && tsc", "format": "prettier --write . && cargo fmt", "format:check": "prettier --check . && cargo fmt -- --check", "lint": "cargo clippy", "lint:fix": "cargo clippy --fix", - "check": "npm run lint && npm run format:check && npm run build && npm run test && npm run docs:ts:verify", + "check:go": "cd go && go build ./...", + "check": "npm run lint && npm run format:check && npm run build && npm run test && npm run docs:ts:verify && npm run check:go", "docs": "cd docs && npx mint dev", "docs:ts:build": "cd typescript && typedoc && echo 'TypeScript documentation generated in ./typescript/docs'", "docs:ts:dev": "concurrently \"cd typescript && typedoc --watch --preserveWatchOutput\" \"npx http-server typescript/docs -p 8081\"", From d4c9caeb8f6c860eb0a838d002b902ba29a1c2a2 Mon Sep 17 00:00:00 2001 From: Thomas Kosiewski Date: Fri, 29 Aug 2025 21:57:29 +0200 Subject: [PATCH 02/22] feat: add terminal support and improve validation Change-Id: I17041dcb362750bad7e95ab5c84ab707dfaec85f Signed-off-by: Thomas Kosiewski --- .gitignore | 4 + go/acp_test.go | 27 +++- go/agent_gen.go | 35 +++- go/client_gen.go | 70 ++++++-- go/cmd/generate/main.go | 314 +++++++++++++++++++++++++++++------- go/example/agent/main.go | 10 +- go/example/client/main.go | 63 +++++++- go/example/gemini/main.go | 74 ++++++++- go/types.go | 327 +++++++++++++++++++++++++++++++++++--- 9 files changed, 820 insertions(+), 104 deletions(-) diff --git a/.gitignore b/.gitignore index 78fefd3..2e7ea93 100644 --- a/.gitignore +++ b/.gitignore @@ -13,3 +13,7 @@ typescript/*.js.map # TypeDoc generated documentation typescript/docs/ + +# Go files +.gocache +.gopath diff --git a/go/acp_test.go b/go/acp_test.go index 0818ff4..743dfe6 100644 --- a/go/acp_test.go +++ b/go/acp_test.go @@ -43,6 +43,21 @@ func (c clientFuncs) SessionUpdate(n SessionNotification) error { return nil } +// Optional/UNSTABLE terminal methods – provide no-op implementations so clientFuncs satisfies Client. +func (c clientFuncs) CreateTerminal(p CreateTerminalRequest) (CreateTerminalResponse, error) { + return CreateTerminalResponse{TerminalId: "term-1"}, nil +} + +func (c clientFuncs) TerminalOutput(p TerminalOutputRequest) (TerminalOutputResponse, error) { + return TerminalOutputResponse{Output: "", Truncated: false}, nil +} + +func (c clientFuncs) ReleaseTerminal(p ReleaseTerminalRequest) error { return nil } + +func (c clientFuncs) WaitForTerminalExit(p WaitForTerminalExitRequest) (WaitForTerminalExitResponse, error) { + return WaitForTerminalExitResponse{}, nil +} + type agentFuncs struct { InitializeFunc func(InitializeRequest) (InitializeResponse, error) NewSessionFunc func(NewSessionRequest) (NewSessionResponse, error) @@ -130,7 +145,7 @@ func TestConnectionHandlesErrorsBidirectional(t *testing.T) { } // Agent->Client direction: expect error - if _, err := c.NewSession(NewSessionRequest{Cwd: "/test", McpServers: nil}); err == nil { + if _, err := c.NewSession(NewSessionRequest{Cwd: "/test", McpServers: []McpServer{}}); err == nil { t.Fatalf("expected error for newSession, got nil") } } @@ -239,7 +254,7 @@ func TestConnectionHandlesMessageOrdering(t *testing.T) { CancelFunc: func(p CancelNotification) error { push("cancelled called: " + string(p.SessionId)); return nil }, }, a2cW, c2aR) - if _, err := cs.NewSession(NewSessionRequest{Cwd: "/test", McpServers: nil}); err != nil { + if _, err := cs.NewSession(NewSessionRequest{Cwd: "/test", McpServers: []McpServer{}}); err != nil { t.Fatalf("newSession error: %v", err) } if err := as.WriteTextFile(WriteTextFileRequest{Path: "/test.txt", Content: "test", SessionId: "test-session"}); err != nil { @@ -252,8 +267,8 @@ func TestConnectionHandlesMessageOrdering(t *testing.T) { SessionId: "test-session", ToolCall: ToolCallUpdate{ Title: "Execute command", - Kind: "execute", - Status: "pending", + Kind: ptr(ToolKindExecute), + Status: ptr(ToolCallStatusPending), ToolCallId: "tool-123", Content: []ToolCallContent{{ Type: "content", @@ -397,3 +412,7 @@ func TestConnectionHandlesInitialize(t *testing.T) { t.Fatalf("unexpected authMethods: %+v", resp.AuthMethods) } } + +func ptr[T any](t T) *T { + return &t +} diff --git a/go/agent_gen.go b/go/agent_gen.go index 3e02093..7d5686d 100644 --- a/go/agent_gen.go +++ b/go/agent_gen.go @@ -11,6 +11,9 @@ func (a *AgentSideConnection) handle(method string, params json.RawMessage) (any if err := json.Unmarshal(params, &p); err != nil { return nil, NewInvalidParams(map[string]any{"error": err.Error()}) } + if err := p.Validate(); err != nil { + return nil, NewInvalidParams(map[string]any{"error": err.Error()}) + } if err := a.agent.Authenticate(p); err != nil { return nil, toReqErr(err) } @@ -20,6 +23,9 @@ func (a *AgentSideConnection) handle(method string, params json.RawMessage) (any if err := json.Unmarshal(params, &p); err != nil { return nil, NewInvalidParams(map[string]any{"error": err.Error()}) } + if err := p.Validate(); err != nil { + return nil, NewInvalidParams(map[string]any{"error": err.Error()}) + } resp, err := a.agent.Initialize(p) if err != nil { return nil, toReqErr(err) @@ -30,6 +36,9 @@ func (a *AgentSideConnection) handle(method string, params json.RawMessage) (any if err := json.Unmarshal(params, &p); err != nil { return nil, NewInvalidParams(map[string]any{"error": err.Error()}) } + if err := p.Validate(); err != nil { + return nil, NewInvalidParams(map[string]any{"error": err.Error()}) + } if err := a.agent.Cancel(p); err != nil { return nil, toReqErr(err) } @@ -39,6 +48,9 @@ func (a *AgentSideConnection) handle(method string, params json.RawMessage) (any if err := json.Unmarshal(params, &p); err != nil { return nil, NewInvalidParams(map[string]any{"error": err.Error()}) } + if err := p.Validate(); err != nil { + return nil, NewInvalidParams(map[string]any{"error": err.Error()}) + } if err := a.agent.LoadSession(p); err != nil { return nil, toReqErr(err) } @@ -48,6 +60,9 @@ func (a *AgentSideConnection) handle(method string, params json.RawMessage) (any if err := json.Unmarshal(params, &p); err != nil { return nil, NewInvalidParams(map[string]any{"error": err.Error()}) } + if err := p.Validate(); err != nil { + return nil, NewInvalidParams(map[string]any{"error": err.Error()}) + } resp, err := a.agent.NewSession(p) if err != nil { return nil, toReqErr(err) @@ -58,6 +73,9 @@ func (a *AgentSideConnection) handle(method string, params json.RawMessage) (any if err := json.Unmarshal(params, &p); err != nil { return nil, NewInvalidParams(map[string]any{"error": err.Error()}) } + if err := p.Validate(); err != nil { + return nil, NewInvalidParams(map[string]any{"error": err.Error()}) + } resp, err := a.agent.Prompt(p) if err != nil { return nil, toReqErr(err) @@ -71,9 +89,6 @@ func (c *AgentSideConnection) ReadTextFile(params ReadTextFileRequest) (ReadText resp, err := SendRequest[ReadTextFileResponse](c.conn, ClientMethodFsReadTextFile, params) return resp, err } -func (c *AgentSideConnection) SessionUpdate(params SessionNotification) error { - return c.conn.SendNotification(ClientMethodSessionUpdate, params) -} func (c *AgentSideConnection) WriteTextFile(params WriteTextFileRequest) error { return c.conn.SendRequestNoResult(ClientMethodFsWriteTextFile, params) } @@ -81,3 +96,17 @@ func (c *AgentSideConnection) RequestPermission(params RequestPermissionRequest) resp, err := SendRequest[RequestPermissionResponse](c.conn, ClientMethodSessionRequestPermission, params) return resp, err } +func (c *AgentSideConnection) SessionUpdate(params SessionNotification) error { + return c.conn.SendNotification(ClientMethodSessionUpdate, params) +} +func (c *AgentSideConnection) CreateTerminal(params CreateTerminalRequest) (CreateTerminalResponse, error) { + resp, err := SendRequest[CreateTerminalResponse](c.conn, ClientMethodTerminalCreate, params) + return resp, err +} +func (c *AgentSideConnection) TerminalOutput(params TerminalOutputRequest) (TerminalOutputResponse, error) { + resp, err := SendRequest[TerminalOutputResponse](c.conn, ClientMethodTerminalOutput, params) + return resp, err +} +func (c *AgentSideConnection) ReleaseTerminal(params ReleaseTerminalRequest) error { + return c.conn.SendRequestNoResult(ClientMethodTerminalRelease, params) +} diff --git a/go/client_gen.go b/go/client_gen.go index fc5d09f..541de88 100644 --- a/go/client_gen.go +++ b/go/client_gen.go @@ -11,6 +11,9 @@ func (c *ClientSideConnection) handle(method string, params json.RawMessage) (an if err := json.Unmarshal(params, &p); err != nil { return nil, NewInvalidParams(map[string]any{"error": err.Error()}) } + if err := p.Validate(); err != nil { + return nil, NewInvalidParams(map[string]any{"error": err.Error()}) + } resp, err := c.client.ReadTextFile(p) if err != nil { return nil, toReqErr(err) @@ -21,6 +24,9 @@ func (c *ClientSideConnection) handle(method string, params json.RawMessage) (an if err := json.Unmarshal(params, &p); err != nil { return nil, NewInvalidParams(map[string]any{"error": err.Error()}) } + if err := p.Validate(); err != nil { + return nil, NewInvalidParams(map[string]any{"error": err.Error()}) + } if err := c.client.WriteTextFile(p); err != nil { return nil, toReqErr(err) } @@ -30,6 +36,9 @@ func (c *ClientSideConnection) handle(method string, params json.RawMessage) (an if err := json.Unmarshal(params, &p); err != nil { return nil, NewInvalidParams(map[string]any{"error": err.Error()}) } + if err := p.Validate(); err != nil { + return nil, NewInvalidParams(map[string]any{"error": err.Error()}) + } resp, err := c.client.RequestPermission(p) if err != nil { return nil, toReqErr(err) @@ -40,32 +49,73 @@ func (c *ClientSideConnection) handle(method string, params json.RawMessage) (an if err := json.Unmarshal(params, &p); err != nil { return nil, NewInvalidParams(map[string]any{"error": err.Error()}) } + if err := p.Validate(); err != nil { + return nil, NewInvalidParams(map[string]any{"error": err.Error()}) + } if err := c.client.SessionUpdate(p); err != nil { return nil, toReqErr(err) } return nil, nil + case ClientMethodTerminalCreate: + var p CreateTerminalRequest + if err := json.Unmarshal(params, &p); err != nil { + return nil, NewInvalidParams(map[string]any{"error": err.Error()}) + } + if err := p.Validate(); err != nil { + return nil, NewInvalidParams(map[string]any{"error": err.Error()}) + } + resp, err := c.client.CreateTerminal(p) + if err != nil { + return nil, toReqErr(err) + } + return resp, nil + case ClientMethodTerminalOutput: + var p TerminalOutputRequest + if err := json.Unmarshal(params, &p); err != nil { + return nil, NewInvalidParams(map[string]any{"error": err.Error()}) + } + if err := p.Validate(); err != nil { + return nil, NewInvalidParams(map[string]any{"error": err.Error()}) + } + resp, err := c.client.TerminalOutput(p) + if err != nil { + return nil, toReqErr(err) + } + return resp, nil + case ClientMethodTerminalRelease: + var p ReleaseTerminalRequest + if err := json.Unmarshal(params, &p); err != nil { + return nil, NewInvalidParams(map[string]any{"error": err.Error()}) + } + if err := p.Validate(); err != nil { + return nil, NewInvalidParams(map[string]any{"error": err.Error()}) + } + if err := c.client.ReleaseTerminal(p); err != nil { + return nil, toReqErr(err) + } + return nil, nil default: return nil, NewMethodNotFound(method) } } -func (c *ClientSideConnection) Cancel(params CancelNotification) error { - return c.conn.SendNotification(AgentMethodSessionCancel, params) -} func (c *ClientSideConnection) Authenticate(params AuthenticateRequest) error { return c.conn.SendRequestNoResult(AgentMethodAuthenticate, params) } -func (c *ClientSideConnection) Prompt(params PromptRequest) (PromptResponse, error) { - resp, err := SendRequest[PromptResponse](c.conn, AgentMethodSessionPrompt, params) +func (c *ClientSideConnection) Initialize(params InitializeRequest) (InitializeResponse, error) { + resp, err := SendRequest[InitializeResponse](c.conn, AgentMethodInitialize, params) return resp, err } +func (c *ClientSideConnection) Cancel(params CancelNotification) error { + return c.conn.SendNotification(AgentMethodSessionCancel, params) +} +func (c *ClientSideConnection) LoadSession(params LoadSessionRequest) error { + return c.conn.SendRequestNoResult(AgentMethodSessionLoad, params) +} func (c *ClientSideConnection) NewSession(params NewSessionRequest) (NewSessionResponse, error) { resp, err := SendRequest[NewSessionResponse](c.conn, AgentMethodSessionNew, params) return resp, err } -func (c *ClientSideConnection) Initialize(params InitializeRequest) (InitializeResponse, error) { - resp, err := SendRequest[InitializeResponse](c.conn, AgentMethodInitialize, params) +func (c *ClientSideConnection) Prompt(params PromptRequest) (PromptResponse, error) { + resp, err := SendRequest[PromptResponse](c.conn, AgentMethodSessionPrompt, params) return resp, err } -func (c *ClientSideConnection) LoadSession(params LoadSessionRequest) error { - return c.conn.SendRequestNoResult(AgentMethodSessionLoad, params) -} diff --git a/go/cmd/generate/main.go b/go/cmd/generate/main.go index 9ee340d..68b57cb 100644 --- a/go/cmd/generate/main.go +++ b/go/cmd/generate/main.go @@ -41,6 +41,9 @@ type Definition struct { XMethod string `json:"x-method"` } +// methodInfo captures the association between a wire method and its Go types. +type methodInfo struct{ Side, Method, Req, Resp, Notif string } + func main() { repoRoot := findRepoRoot() schemaDir := filepath.Join(repoRoot, "schema") @@ -164,7 +167,7 @@ func writeTypesJen(outDir string, schema *Schema, meta *Meta) error { for _, name := range keys { def := schema.Defs[name] - if def == nil || def.DocsIgnore { + if def == nil { continue } @@ -235,7 +238,7 @@ func writeTypesJen(outDir string, schema *Schema, meta *Meta) error { if _, ok := req[pk]; !ok { tag = pk + ",omitempty" } - st = append(st, Id(field).Add(jenTypeFor(prop)).Tag(map[string]string{"json": tag})) + st = append(st, Id(field).Add(jenTypeForOptional(prop)).Tag(map[string]string{"json": tag})) } f.Type().Id(name).Struct(st...) f.Line() @@ -248,32 +251,17 @@ func writeTypesJen(outDir string, schema *Schema, meta *Meta) error { f.Type().Id(name).Any() f.Line() } + + // Emit basic validators for RPC/union types + if strings.HasSuffix(name, "Request") || strings.HasSuffix(name, "Response") || strings.HasSuffix(name, "Notification") || name == "ContentBlock" || name == "ToolCallContent" || name == "SessionUpdate" || name == "ToolCallUpdate" { + emitValidateJen(f, name, def) + } } // Append Agent and Client interfaces derived from meta.json + schema defs { type methodInfo struct{ Side, Method, Req, Resp, Notif string } - groups := map[string]*methodInfo{} - for name, def := range schema.Defs { - if def == nil || def.XMethod == "" || def.XSide == "" { - continue - } - key := def.XSide + "|" + def.XMethod - mi := groups[key] - if mi == nil { - mi = &methodInfo{Side: def.XSide, Method: def.XMethod} - groups[key] = mi - } - if strings.HasSuffix(name, "Request") { - mi.Req = name - } - if strings.HasSuffix(name, "Response") { - mi.Resp = name - } - if strings.HasSuffix(name, "Notification") { - mi.Notif = name - } - } + groups := buildMethodGroups(schema, meta) // Agent methods := []Code{} amKeys := make([]string, 0, len(meta.AgentMethods)) @@ -352,6 +340,174 @@ func isStringConstUnion(def *Definition) bool { return true } +// emitValidateJen generates a simple Validate method for selected types. +func emitValidateJen(f *File, name string, def *Definition) { + switch name { + case "ContentBlock": + f.Func().Params(Id("c").Op("*").Id("ContentBlock")).Id("Validate").Params().Params(Error()).Block( + Switch(Id("c").Dot("Type")).Block( + Case(Lit("text")).Block(If(Id("c").Dot("Text").Op("==").Nil()).Block(Return(Qual("fmt", "Errorf").Call(Lit("contentblock.text missing"))))), + Case(Lit("image")).Block(If(Id("c").Dot("Image").Op("==").Nil()).Block(Return(Qual("fmt", "Errorf").Call(Lit("contentblock.image missing"))))), + Case(Lit("audio")).Block(If(Id("c").Dot("Audio").Op("==").Nil()).Block(Return(Qual("fmt", "Errorf").Call(Lit("contentblock.audio missing"))))), + Case(Lit("resource_link")).Block(If(Id("c").Dot("ResourceLink").Op("==").Nil()).Block(Return(Qual("fmt", "Errorf").Call(Lit("contentblock.resource_link missing"))))), + Case(Lit("resource")).Block(If(Id("c").Dot("Resource").Op("==").Nil()).Block(Return(Qual("fmt", "Errorf").Call(Lit("contentblock.resource missing"))))), + ), + Return(Nil()), + ) + return + case "ToolCallContent": + f.Func().Params(Id("t").Op("*").Id("ToolCallContent")).Id("Validate").Params().Params(Error()).Block( + Switch(Id("t").Dot("Type")).Block( + Case(Lit("content")).Block(If(Id("t").Dot("Content").Op("==").Nil()).Block(Return(Qual("fmt", "Errorf").Call(Lit("toolcallcontent.content missing"))))), + Case(Lit("diff")).Block(If(Id("t").Dot("Diff").Op("==").Nil()).Block(Return(Qual("fmt", "Errorf").Call(Lit("toolcallcontent.diff missing"))))), + Case(Lit("terminal")).Block(If(Id("t").Dot("Terminal").Op("==").Nil()).Block(Return(Qual("fmt", "Errorf").Call(Lit("toolcallcontent.terminal missing"))))), + ), + Return(Nil()), + ) + return + case "SessionUpdate": + f.Func().Params(Id("s").Op("*").Id("SessionUpdate")).Id("Validate").Params().Params(Error()).Block( + Var().Id("count").Int(), + If(Id("s").Dot("UserMessageChunk").Op("!=").Nil()).Block(Id("count").Op("++")), + If(Id("s").Dot("AgentMessageChunk").Op("!=").Nil()).Block(Id("count").Op("++")), + If(Id("s").Dot("AgentThoughtChunk").Op("!=").Nil()).Block(Id("count").Op("++")), + If(Id("s").Dot("ToolCall").Op("!=").Nil()).Block(Id("count").Op("++")), + If(Id("s").Dot("ToolCallUpdate").Op("!=").Nil()).Block(Id("count").Op("++")), + If(Id("s").Dot("Plan").Op("!=").Nil()).Block(Id("count").Op("++")), + If(Id("count").Op("!=").Lit(1)).Block(Return(Qual("fmt", "Errorf").Call(Lit("sessionupdate must have exactly one variant set")))), + Return(Nil()), + ) + return + case "ToolCallUpdate": + f.Func().Params(Id("t").Op("*").Id("ToolCallUpdate")).Id("Validate").Params().Params(Error()).Block( + If(Id("t").Dot("ToolCallId").Op("==").Lit("")).Block(Return(Qual("fmt", "Errorf").Call(Lit("toolCallId is required")))), + Return(Nil()), + ) + return + } + // Generic RPC objects + if def != nil && primaryType(def) == "object" { + if !(strings.HasSuffix(name, "Request") || strings.HasSuffix(name, "Response") || strings.HasSuffix(name, "Notification")) { + return + } + f.Func().Params(Id("v").Op("*").Id(name)).Id("Validate").Params().Params(Error()).BlockFunc(func(g *Group) { + // Emit checks in deterministic property order + pkeys := make([]string, 0, len(def.Properties)) + for pk := range def.Properties { + pkeys = append(pkeys, pk) + } + sort.Strings(pkeys) + for _, propName := range pkeys { + pDef := def.Properties[propName] + // is required? + required := false + for _, r := range def.Required { + if r == propName { + required = true + break + } + } + field := toExportedField(propName) + if required { + switch primaryType(pDef) { + case "string": + g.If(Id("v").Dot(field).Op("==").Lit("")).Block(Return(Qual("fmt", "Errorf").Call(Lit(propName + " is required")))) + case "array": + g.If(Id("v").Dot(field).Op("==").Nil()).Block(Return(Qual("fmt", "Errorf").Call(Lit(propName + " is required")))) + } + } + } + g.Return(Nil()) + }) + } +} + +// buildMethodGroups merges schema-provided links with inferred ones from meta. +func buildMethodGroups(schema *Schema, meta *Meta) map[string]*methodInfo { + groups := map[string]*methodInfo{} + // From schema + for name, def := range schema.Defs { + if def == nil || def.XMethod == "" || def.XSide == "" { + continue + } + key := def.XSide + "|" + def.XMethod + mi := groups[key] + if mi == nil { + mi = &methodInfo{Side: def.XSide, Method: def.XMethod} + groups[key] = mi + } + if strings.HasSuffix(name, "Request") { + mi.Req = name + } + if strings.HasSuffix(name, "Response") { + mi.Resp = name + } + if strings.HasSuffix(name, "Notification") { + mi.Notif = name + } + } + // From meta fallback (e.g., terminal methods) + for key, wire := range meta.AgentMethods { + k := "agent|" + wire + if groups[k] == nil { + base := inferTypeBaseFromMethodKey(key) + mi := &methodInfo{Side: "agent", Method: wire} + if wire == "session/cancel" { + mi.Notif = "CancelNotification" + } else { + if _, ok := schema.Defs[base+"Request"]; ok { + mi.Req = base + "Request" + } + if _, ok := schema.Defs[base+"Response"]; ok { + mi.Resp = base + "Response" + } + } + if mi.Req != "" || mi.Notif != "" { + groups[k] = mi + } + } + } + for key, wire := range meta.ClientMethods { + k := "client|" + wire + if groups[k] == nil { + base := inferTypeBaseFromMethodKey(key) + mi := &methodInfo{Side: "client", Method: wire} + if wire == "session/update" { + mi.Notif = "SessionNotification" + } else { + if _, ok := schema.Defs[base+"Request"]; ok { + mi.Req = base + "Request" + } + if _, ok := schema.Defs[base+"Response"]; ok { + mi.Resp = base + "Response" + } + } + if mi.Req != "" || mi.Notif != "" { + groups[k] = mi + } + } + } + return groups +} + +func inferTypeBaseFromMethodKey(methodKey string) string { + parts := strings.Split(methodKey, "_") + if len(parts) == 2 { + n, v := parts[0], parts[1] + switch v { + case "new", "create", "release", "wait", "load", "authenticate", "prompt", "cancel", "read", "write": + return titleWord(v) + titleWord(n) + default: + return titleWord(n) + titleWord(v) + } + } + segs := strings.Split(methodKey, "_") + for i := range segs { + segs[i] = titleWord(segs[i]) + } + return strings.Join(segs, "") +} + func emitContentBlockJen(f *File) { // ResourceLinkContent helper f.Type().Id("ResourceLinkContent").Struct( @@ -791,6 +947,52 @@ func jenTypeFor(d *Definition) Code { } } +// jenTypeForOptional maps unions that include null to pointer types where applicable. +func jenTypeForOptional(d *Definition) Code { + if d == nil { + return Any() + } + // Check anyOf/oneOf with exactly one non-null + null + list := d.AnyOf + if len(list) == 0 { + list = d.OneOf + } + if len(list) == 2 { + var nonNull *Definition + for _, e := range list { + if e == nil { + continue + } + if s, ok := e.Type.(string); ok && s == "null" { + continue + } + if e.Const != nil { + nn := *e + nn.Type = "string" + nonNull = &nn + } else { + nonNull = e + } + } + if nonNull != nil { + if nonNull.Ref != "" && strings.HasPrefix(nonNull.Ref, "#/$defs/") { + return Op("*").Id(nonNull.Ref[len("#/$defs/"):]) + } + switch primaryType(nonNull) { + case "string": + return Op("*").String() + case "integer": + return Op("*").Int() + case "number": + return Op("*").Float64() + case "boolean": + return Op("*").Bool() + } + } + } + return jenTypeFor(d) +} + func isNullResponse(def *Definition) bool { if def == nil { return true @@ -818,32 +1020,8 @@ func dispatchMethodNameForNotification(methodKey, typeName string) string { } func writeDispatchJen(outDir string, schema *Schema, meta *Meta) error { - // Build method groups - type methodInfo struct { - Side, Method string - Req, Resp, Notif string - } - groups := map[string]*methodInfo{} - for name, def := range schema.Defs { - if def == nil || def.XMethod == "" || def.XSide == "" { - continue - } - key := def.XSide + "|" + def.XMethod - mi := groups[key] - if mi == nil { - mi = &methodInfo{Side: def.XSide, Method: def.XMethod} - groups[key] = mi - } - if strings.HasSuffix(name, "Request") { - mi.Req = name - } - if strings.HasSuffix(name, "Response") { - mi.Resp = name - } - if strings.HasSuffix(name, "Notification") { - mi.Notif = name - } - } + // Build method groups using schema + meta inference + groups := buildMethodGroups(schema, meta) // Agent handler method fAgent := NewFile("acp") @@ -873,6 +1051,10 @@ func writeDispatchJen(outDir string, schema *Schema, meta *Meta) error { ).Block( Return(Nil(), Id("NewInvalidParams").Call(Map(String()).Id("any").Values(Dict{Lit("error"): Id("err").Dot("Error").Call()}))), ), + // Validate if available + If(List(Id("err")).Op(":=").Id("p").Dot("Validate").Call(), Id("err").Op("!=").Nil()).Block( + Return(Nil(), Id("NewInvalidParams").Call(Map(String()).Id("any").Values(Dict{Lit("error"): Id("err").Dot("Error").Call()}))), + ), ) // if err := a.agent.Call(p); err != nil { return nil, toReqErr(err) }; return nil, nil callName := dispatchMethodNameForNotification(k, mi.Notif) @@ -895,6 +1077,9 @@ func writeDispatchJen(outDir string, schema *Schema, meta *Meta) error { ).Block( Return(Nil(), Id("NewInvalidParams").Call(Map(String()).Id("any").Values(Dict{Lit("error"): Id("err").Dot("Error").Call()}))), ), + If(List(Id("err")).Op(":=").Id("p").Dot("Validate").Call(), Id("err").Op("!=").Nil()).Block( + Return(Nil(), Id("NewInvalidParams").Call(Map(String()).Id("any").Values(Dict{Lit("error"): Id("err").Dot("Error").Call()}))), + ), ) methodName := strings.TrimSuffix(mi.Req, "Request") if isNullResponse(schema.Defs[respName]) { @@ -936,9 +1121,16 @@ func writeDispatchJen(outDir string, schema *Schema, meta *Meta) error { for k, v := range meta.ClientMethods { clientConst[v] = "ClientMethod" + toExportedConst(k) } - // Agent outbound: methods the agent can call on the client - for _, mi := range groups { - if mi.Side != "client" { + // Agent outbound: methods the agent can call on the client (stable order) + cmKeys2 := make([]string, 0, len(meta.ClientMethods)) + for k := range meta.ClientMethods { + cmKeys2 = append(cmKeys2, k) + } + sort.Strings(cmKeys2) + for _, k := range cmKeys2 { + wire := meta.ClientMethods[k] + mi := groups["client|"+wire] + if mi == nil { continue } constName := clientConst[mi.Method] @@ -1001,6 +1193,9 @@ func writeDispatchJen(outDir string, schema *Schema, meta *Meta) error { If(List(Id("err")).Op(":=").Qual("encoding/json", "Unmarshal").Call(Id("params"), Op("&").Id("p")), Id("err").Op("!=").Nil()).Block( Return(Nil(), Id("NewInvalidParams").Call(Map(String()).Id("any").Values(Dict{Lit("error"): Id("err").Dot("Error").Call()}))), ), + If(List(Id("err")).Op(":=").Id("p").Dot("Validate").Call(), Id("err").Op("!=").Nil()).Block( + Return(Nil(), Id("NewInvalidParams").Call(Map(String()).Id("any").Values(Dict{Lit("error"): Id("err").Dot("Error").Call()}))), + ), ) callName := dispatchMethodNameForNotification(k, mi.Notif) body = append(body, @@ -1016,6 +1211,9 @@ func writeDispatchJen(outDir string, schema *Schema, meta *Meta) error { If(List(Id("err")).Op(":=").Qual("encoding/json", "Unmarshal").Call(Id("params"), Op("&").Id("p")), Id("err").Op("!=").Nil()).Block( Return(Nil(), Id("NewInvalidParams").Call(Map(String()).Id("any").Values(Dict{Lit("error"): Id("err").Dot("Error").Call()}))), ), + If(List(Id("err")).Op(":=").Id("p").Dot("Validate").Call(), Id("err").Op("!=").Nil()).Block( + Return(Nil(), Id("NewInvalidParams").Call(Map(String()).Id("any").Values(Dict{Lit("error"): Id("err").Dot("Error").Call()}))), + ), ) methodName := strings.TrimSuffix(mi.Req, "Request") if isNullResponse(schema.Defs[respName]) { @@ -1044,8 +1242,16 @@ func writeDispatchJen(outDir string, schema *Schema, meta *Meta) error { Switch(Id("method")).Block(cCases...), ) // After generating the handler, also append outbound wrappers for ClientSideConnection - for _, mi := range groups { - if mi.Side != "agent" { + // Client outbound: methods the client can call on the agent (stable order) + amKeys2 := make([]string, 0, len(meta.AgentMethods)) + for k := range meta.AgentMethods { + amKeys2 = append(amKeys2, k) + } + sort.Strings(amKeys2) + for _, k := range amKeys2 { + wire := meta.AgentMethods[k] + mi := groups["agent|"+wire] + if mi == nil { continue } constName := agentConst[mi.Method] diff --git a/go/example/agent/main.go b/go/example/agent/main.go index ad1fd8a..779148a 100644 --- a/go/example/agent/main.go +++ b/go/example/agent/main.go @@ -172,8 +172,8 @@ func (a *exampleAgent) simulateTurn(ctx context.Context, sid string) error { ToolCall: acp.ToolCallUpdate{ ToolCallId: acp.ToolCallId("call_2"), Title: "Modifying critical configuration file", - Kind: "edit", - Status: "pending", + Kind: ptr(acp.ToolKindEdit), + Status: ptr(acp.ToolCallStatusPending), Locations: []acp.ToolCallLocation{{Path: "/home/user/project/config.json"}}, RawInput: map[string]any{"path": "/home/user/project/config.json", "content": "{\"database\": {\"host\": \"new-host\"}}"}, }, @@ -199,7 +199,7 @@ func (a *exampleAgent) simulateTurn(ctx context.Context, sid string) error { SessionId: acp.SessionId(sid), Update: acp.SessionUpdate{ToolCallUpdate: &acp.SessionUpdateToolCallUpdate{ ToolCallId: acp.ToolCallId("call_2"), - Status: "completed", + Status: ptr(acp.ToolCallStatusCompleted), RawOutput: map[string]any{"success": true, "message": "Configuration updated"}, }}, }); err != nil { @@ -256,6 +256,10 @@ func pause(ctx context.Context, d time.Duration) error { } } +func ptr[T any](t T) *T { + return &t +} + func main() { // Wire up stdio: write to stdout, read from stdin ag := newExampleAgent() diff --git a/go/example/client/main.go b/go/example/client/main.go index c129c36..594d34a 100644 --- a/go/example/client/main.go +++ b/go/example/client/main.go @@ -5,6 +5,7 @@ import ( "fmt" "os" "os/exec" + "path/filepath" "strings" acp "github.com/zed-industries/agent-client-protocol/go" @@ -77,13 +78,69 @@ func displayUpdateKind(u acp.SessionUpdate) string { } func (e *exampleClient) WriteTextFile(params acp.WriteTextFileRequest) error { - fmt.Printf("[Client] Write text file called with: %v\n", params) + if !filepath.IsAbs(params.Path) { + return fmt.Errorf("path must be absolute: %s", params.Path) + } + dir := filepath.Dir(params.Path) + if dir != "" { + if err := os.MkdirAll(dir, 0o755); err != nil { + return fmt.Errorf("mkdir %s: %w", dir, err) + } + } + if err := os.WriteFile(params.Path, []byte(params.Content), 0o644); err != nil { + return fmt.Errorf("write %s: %w", params.Path, err) + } + fmt.Printf("[Client] Wrote %d bytes to %s\n", len(params.Content), params.Path) return nil } func (e *exampleClient) ReadTextFile(params acp.ReadTextFileRequest) (acp.ReadTextFileResponse, error) { - fmt.Printf("[Client] Read text file called with: %v\n", params) - return acp.ReadTextFileResponse{Content: "Mock file content"}, nil + if !filepath.IsAbs(params.Path) { + return acp.ReadTextFileResponse{}, fmt.Errorf("path must be absolute: %s", params.Path) + } + b, err := os.ReadFile(params.Path) + if err != nil { + return acp.ReadTextFileResponse{}, fmt.Errorf("read %s: %w", params.Path, err) + } + content := string(b) + // Apply optional line/limit (1-based line index) + if params.Line > 0 || params.Limit > 0 { + lines := strings.Split(content, "\n") + start := 0 + if params.Line > 0 { + start = min(max(params.Line-1, 0), len(lines)) + } + end := len(lines) + if params.Limit > 0 { + if start+params.Limit < end { + end = start + params.Limit + } + } + content = strings.Join(lines[start:end], "\n") + } + fmt.Printf("[Client] ReadTextFile: %s (%d bytes)\n", params.Path, len(content)) + return acp.ReadTextFileResponse{Content: content}, nil +} + +// Optional/UNSTABLE terminal methods: implement as no-ops for example +func (e *exampleClient) CreateTerminal(params acp.CreateTerminalRequest) (acp.CreateTerminalResponse, error) { + fmt.Printf("[Client] CreateTerminal: %v\n", params) + return acp.CreateTerminalResponse{TerminalId: "term-1"}, nil +} + +func (e *exampleClient) TerminalOutput(params acp.TerminalOutputRequest) (acp.TerminalOutputResponse, error) { + fmt.Printf("[Client] TerminalOutput: %v\n", params) + return acp.TerminalOutputResponse{Output: "", Truncated: false}, nil +} + +func (e *exampleClient) ReleaseTerminal(params acp.ReleaseTerminalRequest) error { + fmt.Printf("[Client] ReleaseTerminal: %v\n", params) + return nil +} + +func (e *exampleClient) WaitForTerminalExit(params acp.WaitForTerminalExitRequest) (acp.WaitForTerminalExitResponse, error) { + fmt.Printf("[Client] WaitForTerminalExit: %v\n", params) + return acp.WaitForTerminalExitResponse{}, nil } func main() { diff --git a/go/example/gemini/main.go b/go/example/gemini/main.go index 18842db..b431801 100644 --- a/go/example/gemini/main.go +++ b/go/example/gemini/main.go @@ -6,6 +6,7 @@ import ( "fmt" "os" "os/exec" + "path/filepath" "strings" acp "github.com/zed-industries/agent-client-protocol/go" @@ -63,9 +64,9 @@ func (c *replClient) SessionUpdate(params acp.SessionNotification) error { case u.AgentMessageChunk != nil: content := u.AgentMessageChunk.Content if content.Type == "text" && content.Text != nil { - fmt.Println(content.Text.Text) + fmt.Printf("[agent] \n%s\n", content.Text.Text) } else { - fmt.Printf("[%s]\n", content.Type) + fmt.Printf("[agent] %s\n", content.Type) } case u.ToolCall != nil: fmt.Printf("\n🔧 %s (%s)\n", u.ToolCall.Title, u.ToolCall.Status) @@ -74,7 +75,12 @@ func (c *replClient) SessionUpdate(params acp.SessionNotification) error { case u.Plan != nil: fmt.Println("[plan update]") case u.AgentThoughtChunk != nil: - fmt.Println("[agent_thought_chunk]") + thought := u.AgentThoughtChunk.Content + if thought.Type == "text" && thought.Text != nil { + fmt.Printf("[agent_thought_chunk] \n%s\n", thought.Text.Text) + } else { + fmt.Println("[agent_thought_chunk]", "(", thought.Type, ")") + } case u.UserMessageChunk != nil: fmt.Println("[user_message_chunk]") } @@ -82,14 +88,68 @@ func (c *replClient) SessionUpdate(params acp.SessionNotification) error { } func (c *replClient) WriteTextFile(params acp.WriteTextFileRequest) error { - // For demo purposes, just log the request and allow it. - fmt.Printf("[Client] WriteTextFile: %v\n", params) + if !filepath.IsAbs(params.Path) { + return fmt.Errorf("path must be absolute: %s", params.Path) + } + dir := filepath.Dir(params.Path) + if dir != "" { + if err := os.MkdirAll(dir, 0o755); err != nil { + return fmt.Errorf("mkdir %s: %w", dir, err) + } + } + if err := os.WriteFile(params.Path, []byte(params.Content), 0o644); err != nil { + return fmt.Errorf("write %s: %w", params.Path, err) + } + fmt.Printf("[Client] Wrote %d bytes to %s\n", len(params.Content), params.Path) return nil } func (c *replClient) ReadTextFile(params acp.ReadTextFileRequest) (acp.ReadTextFileResponse, error) { - fmt.Printf("[Client] ReadTextFile: %v\n", params) - return acp.ReadTextFileResponse{Content: "Mock file content"}, nil + if !filepath.IsAbs(params.Path) { + return acp.ReadTextFileResponse{}, fmt.Errorf("path must be absolute: %s", params.Path) + } + b, err := os.ReadFile(params.Path) + if err != nil { + return acp.ReadTextFileResponse{}, fmt.Errorf("read %s: %w", params.Path, err) + } + content := string(b) + if params.Line > 0 || params.Limit > 0 { + lines := strings.Split(content, "\n") + start := 0 + if params.Line > 0 { + start = min(max(params.Line-1, 0), len(lines)) + } + end := len(lines) + if params.Limit > 0 { + if start+params.Limit < end { + end = start + params.Limit + } + } + content = strings.Join(lines[start:end], "\n") + } + fmt.Printf("[Client] ReadTextFile: %s (%d bytes)\n", params.Path, len(content)) + return acp.ReadTextFileResponse{Content: content}, nil +} + +// Optional/UNSTABLE terminal methods: implement as no-ops for example +func (c *replClient) CreateTerminal(params acp.CreateTerminalRequest) (acp.CreateTerminalResponse, error) { + fmt.Printf("[Client] CreateTerminal: %v\n", params) + return acp.CreateTerminalResponse{TerminalId: "term-1"}, nil +} + +func (c *replClient) TerminalOutput(params acp.TerminalOutputRequest) (acp.TerminalOutputResponse, error) { + fmt.Printf("[Client] TerminalOutput: %v\n", params) + return acp.TerminalOutputResponse{Output: "", Truncated: false}, nil +} + +func (c *replClient) ReleaseTerminal(params acp.ReleaseTerminalRequest) error { + fmt.Printf("[Client] ReleaseTerminal: %v\n", params) + return nil +} + +func (c *replClient) WaitForTerminalExit(params acp.WaitForTerminalExitRequest) (acp.WaitForTerminalExitResponse, error) { + fmt.Printf("[Client] WaitForTerminalExit: %v\n", params) + return acp.WaitForTerminalExitResponse{}, nil } func main() { diff --git a/go/types.go b/go/types.go index 58a5b2c..4f86231 100644 --- a/go/types.go +++ b/go/types.go @@ -2,7 +2,10 @@ package acp -import "encoding/json" +import ( + "encoding/json" + "fmt" +) // Capabilities supported by the agent. Advertised during initialization to inform the client about available features and content types. See protocol docs: [Agent Capabilities](https://agentclientprotocol.com/protocol/initialization#agent-capabilities) type AgentCapabilities struct { @@ -12,6 +15,18 @@ type AgentCapabilities struct { PromptCapabilities PromptCapabilities `json:"promptCapabilities,omitempty"` } +// All possible notifications that an agent can send to a client. This enum is used internally for routing RPC notifications. You typically won't need to use this directly - use the notification methods on the ['Client'] trait instead. Notifications do not expect a response. +// AgentNotification is a union or complex schema; represented generically. +type AgentNotification any + +// All possible requests that an agent can send to a client. This enum is used internally for routing RPC requests. You typically won't need to use this directly - instead, use the methods on the ['Client'] trait. This enum encompasses all method calls from agent to client. +// AgentRequest is a union or complex schema; represented generically. +type AgentRequest any + +// All possible responses that an agent can send to a client. This enum is used internally for routing RPC responses. You typically won't need to use this directly - the responses are handled automatically by the connection. These are responses to the corresponding ClientRequest variants. +// AgentResponse is a union or complex schema; represented generically. +type AgentResponse any + // Optional annotations for the client. The client can use annotations to inform how objects are used or displayed type Annotations struct { Audience []Role `json:"audience,omitempty"` @@ -21,9 +36,9 @@ type Annotations struct { // Audio provided to or from an LLM. type AudioContent struct { - Annotations any `json:"annotations,omitempty"` - Data string `json:"data"` - MimeType string `json:"mimeType"` + Annotations *Annotations `json:"annotations,omitempty"` + Data string `json:"data"` + MimeType string `json:"mimeType"` } // Describes an available authentication method. @@ -45,6 +60,10 @@ type AuthenticateRequest struct { MethodId AuthMethodId `json:"methodId"` } +func (v *AuthenticateRequest) Validate() error { + return nil +} + // Binary resource contents. type BlobResourceContents struct { Blob string `json:"blob"` @@ -58,6 +77,10 @@ type CancelNotification struct { SessionId SessionId `json:"sessionId"` } +func (v *CancelNotification) Validate() error { + return nil +} + // Capabilities supported by the client. Advertised during initialization to inform the agent about available features and methods. See protocol docs: [Client Capabilities](https://agentclientprotocol.com/protocol/initialization#client-capabilities) type ClientCapabilities struct { // File system capabilities supported by the client. Determines which file operations the agent can request. @@ -66,6 +89,18 @@ type ClientCapabilities struct { Terminal bool `json:"terminal,omitempty"` } +// All possible notifications that a client can send to an agent. This enum is used internally for routing RPC notifications. You typically won't need to use this directly - use the notification methods on the ['Agent'] trait instead. Notifications do not expect a response. +// ClientNotification is a union or complex schema; represented generically. +type ClientNotification any + +// All possible requests that a client can send to an agent. This enum is used internally for routing RPC requests. You typically won't need to use this directly - instead, use the methods on the ['Agent'] trait. This enum encompasses all method calls from client to agent. +// ClientRequest is a union or complex schema; represented generically. +type ClientRequest any + +// All possible responses that a client can send to an agent. This enum is used internally for routing RPC responses. You typically won't need to use this directly - the responses are handled automatically by the connection. These are responses to the corresponding AgentRequest variants. +// ClientResponse is a union or complex schema; represented generically. +type ClientResponse any + // Content blocks represent displayable information in the Agent Client Protocol. They provide a structured way to handle various types of user-facing content—whether it's text from language models, images for analysis, or embedded resources for context. Content blocks appear in: - User prompts sent via 'session/prompt' - Language model output streamed through 'session/update' notifications - Progress updates and results from tool calls This structure is compatible with the Model Context Protocol (MCP), enabling agents to seamlessly forward content from MCP tool outputs without transformation. See protocol docs: [Content](https://agentclientprotocol.com/protocol/content) type ResourceLinkContent struct { Annotations any `json:"annotations,omitempty"` @@ -177,9 +212,62 @@ func (c ContentBlock) MarshalJSON() ([]byte, error) { return []byte{}, nil } +func (c *ContentBlock) Validate() error { + switch c.Type { + case "text": + if c.Text == nil { + return fmt.Errorf("contentblock.text missing") + } + case "image": + if c.Image == nil { + return fmt.Errorf("contentblock.image missing") + } + case "audio": + if c.Audio == nil { + return fmt.Errorf("contentblock.audio missing") + } + case "resource_link": + if c.ResourceLink == nil { + return fmt.Errorf("contentblock.resource_link missing") + } + case "resource": + if c.Resource == nil { + return fmt.Errorf("contentblock.resource missing") + } + } + return nil +} + +type CreateTerminalRequest struct { + Args []string `json:"args,omitempty"` + Command string `json:"command"` + Cwd string `json:"cwd,omitempty"` + Env []EnvVariable `json:"env,omitempty"` + OutputByteLimit int `json:"outputByteLimit,omitempty"` + SessionId SessionId `json:"sessionId"` +} + +func (v *CreateTerminalRequest) Validate() error { + if v.Command == "" { + return fmt.Errorf("command is required") + } + return nil +} + +type CreateTerminalResponse struct { + TerminalId string `json:"terminalId"` +} + +func (v *CreateTerminalResponse) Validate() error { + if v.TerminalId == "" { + return fmt.Errorf("terminalId is required") + } + return nil +} + // The contents of a resource, embedded into a prompt or tool call result. type EmbeddedResource struct { - Annotations any `json:"annotations,omitempty"` + Annotations *Annotations `json:"annotations,omitempty"` Resource EmbeddedResourceResource `json:"resource"` } @@ -231,10 +319,10 @@ type FileSystemCapability struct { // An image provided to or from an LLM. type ImageContent struct { - Annotations any `json:"annotations,omitempty"` - Data string `json:"data"` - MimeType string `json:"mimeType"` - Uri string `json:"uri,omitempty"` + Annotations *Annotations `json:"annotations,omitempty"` + Data string `json:"data"` + MimeType string `json:"mimeType"` + Uri string `json:"uri,omitempty"` } // Request parameters for the initialize method. Sent by the client to establish connection and negotiate capabilities. See protocol docs: [Initialization](https://agentclientprotocol.com/protocol/initialization) @@ -245,6 +333,10 @@ type InitializeRequest struct { ProtocolVersion ProtocolVersion `json:"protocolVersion"` } +func (v *InitializeRequest) Validate() error { + return nil +} + // Response from the initialize method. Contains the negotiated protocol version and agent capabilities. See protocol docs: [Initialization](https://agentclientprotocol.com/protocol/initialization) type InitializeResponse struct { // Capabilities supported by the agent. @@ -255,6 +347,10 @@ type InitializeResponse struct { ProtocolVersion ProtocolVersion `json:"protocolVersion"` } +func (v *InitializeResponse) Validate() error { + return nil +} + // Request parameters for loading an existing session. Only available if the agent supports the 'loadSession' capability. See protocol docs: [Loading Sessions](https://agentclientprotocol.com/protocol/session-setup#loading-sessions) type LoadSessionRequest struct { // The working directory for this session. @@ -265,6 +361,16 @@ type LoadSessionRequest struct { SessionId SessionId `json:"sessionId"` } +func (v *LoadSessionRequest) Validate() error { + if v.Cwd == "" { + return fmt.Errorf("cwd is required") + } + if v.McpServers == nil { + return fmt.Errorf("mcpServers is required") + } + return nil +} + // Configuration for connecting to an MCP (Model Context Protocol) server. MCP servers provide tools and context that the agent can use when processing prompts. See protocol docs: [MCP Servers](https://agentclientprotocol.com/protocol/session-setup#mcp-servers) type McpServer struct { // Command-line arguments to pass to the MCP server. @@ -285,12 +391,26 @@ type NewSessionRequest struct { McpServers []McpServer `json:"mcpServers"` } +func (v *NewSessionRequest) Validate() error { + if v.Cwd == "" { + return fmt.Errorf("cwd is required") + } + if v.McpServers == nil { + return fmt.Errorf("mcpServers is required") + } + return nil +} + // Response from creating a new session. See protocol docs: [Creating a Session](https://agentclientprotocol.com/protocol/session-setup#creating-a-session) type NewSessionResponse struct { // Unique identifier for the created session. Used in all subsequent requests for this conversation. SessionId SessionId `json:"sessionId"` } +func (v *NewSessionResponse) Validate() error { + return nil +} + // An option presented to the user when requesting permission. type PermissionOption struct { // Hint about the nature of this permission option. @@ -366,12 +486,23 @@ type PromptRequest struct { SessionId SessionId `json:"sessionId"` } +func (v *PromptRequest) Validate() error { + if v.Prompt == nil { + return fmt.Errorf("prompt is required") + } + return nil +} + // Response from processing a user prompt. See protocol docs: [Check for Completion](https://agentclientprotocol.com/protocol/prompt-turn#4-check-for-completion) type PromptResponse struct { // Indicates why the agent stopped processing the turn. StopReason StopReason `json:"stopReason"` } +func (v *PromptResponse) Validate() error { + return nil +} + // Protocol version identifier. This version is only bumped for breaking changes. Non-breaking changes should be introduced via capabilities. type ProtocolVersion int @@ -387,11 +518,37 @@ type ReadTextFileRequest struct { SessionId SessionId `json:"sessionId"` } +func (v *ReadTextFileRequest) Validate() error { + if v.Path == "" { + return fmt.Errorf("path is required") + } + return nil +} + // Response containing the contents of a text file. type ReadTextFileResponse struct { Content string `json:"content"` } +func (v *ReadTextFileResponse) Validate() error { + if v.Content == "" { + return fmt.Errorf("content is required") + } + return nil +} + +type ReleaseTerminalRequest struct { + SessionId SessionId `json:"sessionId"` + TerminalId string `json:"terminalId"` +} + +func (v *ReleaseTerminalRequest) Validate() error { + if v.TerminalId == "" { + return fmt.Errorf("terminalId is required") + } + return nil +} + // The outcome of a permission request. type RequestPermissionOutcomeCancelled struct{} type RequestPermissionOutcomeSelected struct { @@ -449,21 +606,32 @@ type RequestPermissionRequest struct { ToolCall ToolCallUpdate `json:"toolCall"` } +func (v *RequestPermissionRequest) Validate() error { + if v.Options == nil { + return fmt.Errorf("options is required") + } + return nil +} + // Response to a permission request. type RequestPermissionResponse struct { // The user's decision on the permission request. Outcome RequestPermissionOutcome `json:"outcome"` } +func (v *RequestPermissionResponse) Validate() error { + return nil +} + // A resource that the server is capable of reading, included in a prompt or tool call result. type ResourceLink struct { - Annotations any `json:"annotations,omitempty"` - Description string `json:"description,omitempty"` - MimeType string `json:"mimeType,omitempty"` - Name string `json:"name"` - Size int `json:"size,omitempty"` - Title string `json:"title,omitempty"` - Uri string `json:"uri"` + Annotations *Annotations `json:"annotations,omitempty"` + Description string `json:"description,omitempty"` + MimeType string `json:"mimeType,omitempty"` + Name string `json:"name"` + Size int `json:"size,omitempty"` + Title string `json:"title,omitempty"` + Uri string `json:"uri"` } // The sender or recipient of messages and data in a conversation. @@ -485,6 +653,10 @@ type SessionNotification struct { Update SessionUpdate `json:"update"` } +func (v *SessionNotification) Validate() error { + return nil +} + // Different types of updates that can be sent during session processing. These updates provide real-time feedback about the agent's progress. See protocol docs: [Agent Reports Output](https://agentclientprotocol.com/protocol/prompt-turn#3-agent-reports-output) type SessionUpdateUserMessageChunk struct { Content ContentBlock `json:"content"` @@ -637,6 +809,32 @@ func (s SessionUpdate) MarshalJSON() ([]byte, error) { return []byte{}, nil } +func (s *SessionUpdate) Validate() error { + var count int + if s.UserMessageChunk != nil { + count++ + } + if s.AgentMessageChunk != nil { + count++ + } + if s.AgentThoughtChunk != nil { + count++ + } + if s.ToolCall != nil { + count++ + } + if s.ToolCallUpdate != nil { + count++ + } + if s.Plan != nil { + count++ + } + if count != 1 { + return fmt.Errorf("sessionupdate must have exactly one variant set") + } + return nil +} + // Reasons why an agent stops processing a prompt turn. See protocol docs: [Stop Reasons](https://agentclientprotocol.com/protocol/prompt-turn#stop-reasons) type StopReason string @@ -648,10 +846,40 @@ const ( StopReasonCancelled StopReason = "cancelled" ) +type TerminalExitStatus struct { + ExitCode int `json:"exitCode,omitempty"` + Signal string `json:"signal,omitempty"` +} + +type TerminalOutputRequest struct { + SessionId SessionId `json:"sessionId"` + TerminalId string `json:"terminalId"` +} + +func (v *TerminalOutputRequest) Validate() error { + if v.TerminalId == "" { + return fmt.Errorf("terminalId is required") + } + return nil +} + +type TerminalOutputResponse struct { + ExitStatus *TerminalExitStatus `json:"exitStatus,omitempty"` + Output string `json:"output"` + Truncated bool `json:"truncated"` +} + +func (v *TerminalOutputResponse) Validate() error { + if v.Output == "" { + return fmt.Errorf("output is required") + } + return nil +} + // Text provided to or from an LLM. type TextContent struct { - Annotations any `json:"annotations,omitempty"` - Text string `json:"text"` + Annotations *Annotations `json:"annotations,omitempty"` + Text string `json:"text"` } // Text-based resource contents. @@ -732,6 +960,24 @@ func (t *ToolCallContent) UnmarshalJSON(b []byte) error { return nil } +func (t *ToolCallContent) Validate() error { + switch t.Type { + case "content": + if t.Content == nil { + return fmt.Errorf("toolcallcontent.content missing") + } + case "diff": + if t.Diff == nil { + return fmt.Errorf("toolcallcontent.diff missing") + } + case "terminal": + if t.Terminal == nil { + return fmt.Errorf("toolcallcontent.terminal missing") + } + } + return nil +} + // Unique identifier for a tool call within a session. type ToolCallId string @@ -758,7 +1004,7 @@ type ToolCallUpdate struct { // Replace the content collection. Content []ToolCallContent `json:"content,omitempty"` // Update the tool kind. - Kind any `json:"kind,omitempty"` + Kind *ToolKind `json:"kind,omitempty"` // Replace the locations collection. Locations []ToolCallLocation `json:"locations,omitempty"` // Update the raw input. @@ -766,13 +1012,20 @@ type ToolCallUpdate struct { // Update the raw output. RawOutput any `json:"rawOutput,omitempty"` // Update the execution status. - Status any `json:"status,omitempty"` + Status *ToolCallStatus `json:"status,omitempty"` // Update the human-readable title. Title string `json:"title,omitempty"` // The ID of the tool call being updated. ToolCallId ToolCallId `json:"toolCallId"` } +func (t *ToolCallUpdate) Validate() error { + if t.ToolCallId == "" { + return fmt.Errorf("toolCallId is required") + } + return nil +} + // Categories of tools that can be invoked. Tool kinds help clients choose appropriate icons and optimize how they display tool execution progress. See protocol docs: [Creating](https://agentclientprotocol.com/protocol/tool-calls#creating) type ToolKind string @@ -788,6 +1041,27 @@ const ( ToolKindOther ToolKind = "other" ) +type WaitForTerminalExitRequest struct { + SessionId SessionId `json:"sessionId"` + TerminalId string `json:"terminalId"` +} + +func (v *WaitForTerminalExitRequest) Validate() error { + if v.TerminalId == "" { + return fmt.Errorf("terminalId is required") + } + return nil +} + +type WaitForTerminalExitResponse struct { + ExitCode int `json:"exitCode,omitempty"` + Signal string `json:"signal,omitempty"` +} + +func (v *WaitForTerminalExitResponse) Validate() error { + return nil +} + // Request to write content to a text file. Only available if the client supports the 'fs.writeTextFile' capability. type WriteTextFileRequest struct { // The text content to write to the file. @@ -798,6 +1072,16 @@ type WriteTextFileRequest struct { SessionId SessionId `json:"sessionId"` } +func (v *WriteTextFileRequest) Validate() error { + if v.Content == "" { + return fmt.Errorf("content is required") + } + if v.Path == "" { + return fmt.Errorf("path is required") + } + return nil +} + type Agent interface { Authenticate(params AuthenticateRequest) error Initialize(params InitializeRequest) (InitializeResponse, error) @@ -811,4 +1095,7 @@ type Client interface { WriteTextFile(params WriteTextFileRequest) error RequestPermission(params RequestPermissionRequest) (RequestPermissionResponse, error) SessionUpdate(params SessionNotification) error + CreateTerminal(params CreateTerminalRequest) (CreateTerminalResponse, error) + TerminalOutput(params TerminalOutputRequest) (TerminalOutputResponse, error) + ReleaseTerminal(params ReleaseTerminalRequest) error } From 7784e3ee41936e81ccc9678c52bec6b9b53ddc1e Mon Sep 17 00:00:00 2001 From: Thomas Kosiewski Date: Sat, 30 Aug 2025 14:38:16 +0200 Subject: [PATCH 03/22] feat: refactor interfaces to separate stable, experimental, and optional methods Change-Id: I4d1ebe3b6f5b53a52e1ec9c13e4028b4d05a6b5e Signed-off-by: Thomas Kosiewski --- go/acp_test.go | 22 ++-- go/agent_gen.go | 10 +- go/client_gen.go | 35 +++++- go/cmd/generate/main.go | 223 ++++++++++++++++++++++++++++++++------ go/example/agent/main.go | 5 +- go/example/client/main.go | 12 +- go/example/gemini/main.go | 16 ++- go/types.go | 11 +- 8 files changed, 275 insertions(+), 59 deletions(-) diff --git a/go/acp_test.go b/go/acp_test.go index 743dfe6..d2558bd 100644 --- a/go/acp_test.go +++ b/go/acp_test.go @@ -15,6 +15,8 @@ type clientFuncs struct { SessionUpdateFunc func(SessionNotification) error } +var _ Client = (*clientFuncs)(nil) + func (c clientFuncs) WriteTextFile(p WriteTextFileRequest) error { if c.WriteTextFileFunc != nil { return c.WriteTextFileFunc(p) @@ -43,21 +45,6 @@ func (c clientFuncs) SessionUpdate(n SessionNotification) error { return nil } -// Optional/UNSTABLE terminal methods – provide no-op implementations so clientFuncs satisfies Client. -func (c clientFuncs) CreateTerminal(p CreateTerminalRequest) (CreateTerminalResponse, error) { - return CreateTerminalResponse{TerminalId: "term-1"}, nil -} - -func (c clientFuncs) TerminalOutput(p TerminalOutputRequest) (TerminalOutputResponse, error) { - return TerminalOutputResponse{Output: "", Truncated: false}, nil -} - -func (c clientFuncs) ReleaseTerminal(p ReleaseTerminalRequest) error { return nil } - -func (c clientFuncs) WaitForTerminalExit(p WaitForTerminalExitRequest) (WaitForTerminalExitResponse, error) { - return WaitForTerminalExitResponse{}, nil -} - type agentFuncs struct { InitializeFunc func(InitializeRequest) (InitializeResponse, error) NewSessionFunc func(NewSessionRequest) (NewSessionResponse, error) @@ -67,6 +54,11 @@ type agentFuncs struct { CancelFunc func(CancelNotification) error } +var ( + _ Agent = (*agentFuncs)(nil) + _ AgentLoader = (*agentFuncs)(nil) +) + func (a agentFuncs) Initialize(p InitializeRequest) (InitializeResponse, error) { if a.InitializeFunc != nil { return a.InitializeFunc(p) diff --git a/go/agent_gen.go b/go/agent_gen.go index 7d5686d..9897a38 100644 --- a/go/agent_gen.go +++ b/go/agent_gen.go @@ -51,7 +51,11 @@ func (a *AgentSideConnection) handle(method string, params json.RawMessage) (any if err := p.Validate(); err != nil { return nil, NewInvalidParams(map[string]any{"error": err.Error()}) } - if err := a.agent.LoadSession(p); err != nil { + loader, ok := a.agent.(AgentLoader) + if !ok { + return nil, NewMethodNotFound(method) + } + if err := loader.LoadSession(p); err != nil { return nil, toReqErr(err) } return nil, nil @@ -110,3 +114,7 @@ func (c *AgentSideConnection) TerminalOutput(params TerminalOutputRequest) (Term func (c *AgentSideConnection) ReleaseTerminal(params ReleaseTerminalRequest) error { return c.conn.SendRequestNoResult(ClientMethodTerminalRelease, params) } +func (c *AgentSideConnection) WaitForTerminalExit(params WaitForTerminalExitRequest) (WaitForTerminalExitResponse, error) { + resp, err := SendRequest[WaitForTerminalExitResponse](c.conn, ClientMethodTerminalWaitForExit, params) + return resp, err +} diff --git a/go/client_gen.go b/go/client_gen.go index 541de88..09f592c 100644 --- a/go/client_gen.go +++ b/go/client_gen.go @@ -64,7 +64,11 @@ func (c *ClientSideConnection) handle(method string, params json.RawMessage) (an if err := p.Validate(); err != nil { return nil, NewInvalidParams(map[string]any{"error": err.Error()}) } - resp, err := c.client.CreateTerminal(p) + t, ok := c.client.(ClientExperimental) + if !ok { + return nil, NewMethodNotFound(method) + } + resp, err := t.CreateTerminal(p) if err != nil { return nil, toReqErr(err) } @@ -77,7 +81,11 @@ func (c *ClientSideConnection) handle(method string, params json.RawMessage) (an if err := p.Validate(); err != nil { return nil, NewInvalidParams(map[string]any{"error": err.Error()}) } - resp, err := c.client.TerminalOutput(p) + t, ok := c.client.(ClientExperimental) + if !ok { + return nil, NewMethodNotFound(method) + } + resp, err := t.TerminalOutput(p) if err != nil { return nil, toReqErr(err) } @@ -90,10 +98,31 @@ func (c *ClientSideConnection) handle(method string, params json.RawMessage) (an if err := p.Validate(); err != nil { return nil, NewInvalidParams(map[string]any{"error": err.Error()}) } - if err := c.client.ReleaseTerminal(p); err != nil { + t, ok := c.client.(ClientExperimental) + if !ok { + return nil, NewMethodNotFound(method) + } + if err := t.ReleaseTerminal(p); err != nil { return nil, toReqErr(err) } return nil, nil + case ClientMethodTerminalWaitForExit: + var p WaitForTerminalExitRequest + if err := json.Unmarshal(params, &p); err != nil { + return nil, NewInvalidParams(map[string]any{"error": err.Error()}) + } + if err := p.Validate(); err != nil { + return nil, NewInvalidParams(map[string]any{"error": err.Error()}) + } + t, ok := c.client.(ClientExperimental) + if !ok { + return nil, NewMethodNotFound(method) + } + resp, err := t.WaitForTerminalExit(p) + if err != nil { + return nil, toReqErr(err) + } + return resp, nil default: return nil, NewMethodNotFound(method) } diff --git a/go/cmd/generate/main.go b/go/cmd/generate/main.go index 68b57cb..720801c 100644 --- a/go/cmd/generate/main.go +++ b/go/cmd/generate/main.go @@ -260,10 +260,37 @@ func writeTypesJen(outDir string, schema *Schema, meta *Meta) error { // Append Agent and Client interfaces derived from meta.json + schema defs { - type methodInfo struct{ Side, Method, Req, Resp, Notif string } groups := buildMethodGroups(schema, meta) + + // Helper: determine if a method is undocumented (x-docs-ignore) + isDocsIgnored := func(mi *methodInfo) bool { + if mi == nil { + return false + } + if mi.Req != "" { + if d := schema.Defs[mi.Req]; d != nil && d.DocsIgnore { + return true + } + } + if mi.Resp != "" { + if d := schema.Defs[mi.Resp]; d != nil && d.DocsIgnore { + return true + } + } + if mi.Notif != "" { + if d := schema.Defs[mi.Notif]; d != nil && d.DocsIgnore { + return true + } + } + return false + } + // Agent - methods := []Code{} + agentMethods := []Code{} + // Optional loader methods live on a separate interface + agentLoaderMethods := []Code{} + // Undocumented/experimental methods live on a separate interface + agentExperimentalMethods := []Code{} amKeys := make([]string, 0, len(meta.AgentMethods)) for k := range meta.AgentMethods { amKeys = append(amKeys, k) @@ -275,22 +302,42 @@ func writeTypesJen(outDir string, schema *Schema, meta *Meta) error { if mi == nil { continue } + // Treat session/load as optional (AgentLoader) + target := &agentMethods + if wire == "session/load" { + target = &agentLoaderMethods + } + // Undocumented/experimental agent methods go to AgentExperimental + if isDocsIgnored(mi) { + target = &agentExperimentalMethods + } if mi.Notif != "" { name := dispatchMethodNameForNotification(k, mi.Notif) - methods = append(methods, Id(name).Params(Id("params").Id(mi.Notif)).Error()) + *target = append(*target, Id(name).Params(Id("params").Id(mi.Notif)).Error()) } else if mi.Req != "" { respName := strings.TrimSuffix(mi.Req, "Request") + "Response" methodName := strings.TrimSuffix(mi.Req, "Request") if isNullResponse(schema.Defs[respName]) { - methods = append(methods, Id(methodName).Params(Id("params").Id(mi.Req)).Error()) + *target = append(*target, Id(methodName).Params(Id("params").Id(mi.Req)).Error()) } else { - methods = append(methods, Id(methodName).Params(Id("params").Id(mi.Req)).Params(Id(respName), Error())) + *target = append(*target, Id(methodName).Params(Id("params").Id(mi.Req)).Params(Id(respName), Error())) } } } - f.Type().Id("Agent").Interface(methods...) + // Emit interfaces + f.Type().Id("Agent").Interface(agentMethods...) + if len(agentLoaderMethods) > 0 { + f.Comment("AgentLoader defines optional support for loading sessions. Implement and advertise the capability to enable 'session/load'.") + f.Type().Id("AgentLoader").Interface(agentLoaderMethods...) + } + if len(agentExperimentalMethods) > 0 { + f.Comment("AgentExperimental defines undocumented/experimental methods (x-docs-ignore). These may change or be removed without notice.") + f.Type().Id("AgentExperimental").Interface(agentExperimentalMethods...) + } + // Client - methods = []Code{} + clientStable := []Code{} + clientExperimental := []Code{} cmKeys := make([]string, 0, len(meta.ClientMethods)) for k := range meta.ClientMethods { cmKeys = append(cmKeys, k) @@ -302,20 +349,28 @@ func writeTypesJen(outDir string, schema *Schema, meta *Meta) error { if mi == nil { continue } + target := &clientStable + if isDocsIgnored(mi) { + target = &clientExperimental + } if mi.Notif != "" { name := dispatchMethodNameForNotification(k, mi.Notif) - methods = append(methods, Id(name).Params(Id("params").Id(mi.Notif)).Error()) + *target = append(*target, Id(name).Params(Id("params").Id(mi.Notif)).Error()) } else if mi.Req != "" { respName := strings.TrimSuffix(mi.Req, "Request") + "Response" methodName := strings.TrimSuffix(mi.Req, "Request") if isNullResponse(schema.Defs[respName]) { - methods = append(methods, Id(methodName).Params(Id("params").Id(mi.Req)).Error()) + *target = append(*target, Id(methodName).Params(Id("params").Id(mi.Req)).Error()) } else { - methods = append(methods, Id(methodName).Params(Id("params").Id(mi.Req)).Params(Id(respName), Error())) + *target = append(*target, Id(methodName).Params(Id("params").Id(mi.Req)).Params(Id(respName), Error())) } } } - f.Type().Id("Client").Interface(methods...) + f.Type().Id("Client").Interface(clientStable...) + if len(clientExperimental) > 0 { + f.Comment("ClientExperimental defines undocumented/experimental methods (x-docs-ignore), such as terminal support. Implement and advertise the related capability to enable them.") + f.Type().Id("ClientExperimental").Interface(clientExperimental...) + } } var buf bytes.Buffer @@ -340,6 +395,30 @@ func isStringConstUnion(def *Definition) bool { return true } +// isDocsIgnoredMethod returns true if any of the method's associated types +// (request, response, notification) are marked with x-docs-ignore in the schema. +func isDocsIgnoredMethod(schema *Schema, mi *methodInfo) bool { + if mi == nil { + return false + } + if mi.Req != "" { + if d := schema.Defs[mi.Req]; d != nil && d.DocsIgnore { + return true + } + } + if mi.Resp != "" { + if d := schema.Defs[mi.Resp]; d != nil && d.DocsIgnore { + return true + } + } + if mi.Notif != "" { + if d := schema.Defs[mi.Notif]; d != nil && d.DocsIgnore { + return true + } + } + return false +} + // emitValidateJen generates a simple Validate method for selected types. func emitValidateJen(f *File, name string, def *Definition) { switch name { @@ -491,6 +570,10 @@ func buildMethodGroups(schema *Schema, meta *Meta) map[string]*methodInfo { } func inferTypeBaseFromMethodKey(methodKey string) string { + // Special-case known irregular mappings + if methodKey == "terminal_wait_for_exit" { + return "WaitForTerminalExit" + } parts := strings.Split(methodKey, "_") if len(parts) == 2 { n, v := parts[0], parts[1] @@ -1082,22 +1165,75 @@ func writeDispatchJen(outDir string, schema *Schema, meta *Meta) error { ), ) methodName := strings.TrimSuffix(mi.Req, "Request") - if isNullResponse(schema.Defs[respName]) { + // Optional: session/load lives on AgentLoader + if wire == "session/load" { + // Perform type assertion first, then branch caseBody = append(caseBody, - If( - List(Id("err")).Op(":=").Id("a").Dot("agent").Dot(methodName).Call(Id("p")), - Id("err").Op("!=").Nil(), - ).Block( - Return(Nil(), Id("toReqErr").Call(Id("err"))), + List(Id("loader"), Id("ok")).Op(":=").Id("a").Dot("agent").Assert(Id("AgentLoader")), + If(Op("!").Id("ok")).Block( + Return(Nil(), Id("NewMethodNotFound").Call(Id("method"))), ), - Return(Nil(), Nil()), ) - } else { + if isNullResponse(schema.Defs[respName]) { + caseBody = append(caseBody, + If( + List(Id("err")).Op(":=").Id("loader").Dot(methodName).Call(Id("p")), + Id("err").Op("!=").Nil(), + ).Block( + Return(Nil(), Id("toReqErr").Call(Id("err"))), + ), + Return(Nil(), Nil()), + ) + } else { + caseBody = append(caseBody, + List(Id("resp"), Id("err")).Op(":=").Id("loader").Dot(methodName).Call(Id("p")), + If(Id("err").Op("!=").Nil()).Block(Return(Nil(), Id("toReqErr").Call(Id("err")))), + Return(Id("resp"), Nil()), + ) + } + } else if isDocsIgnoredMethod(schema, mi) { + // Undocumented/experimental agent methods require AgentExperimental caseBody = append(caseBody, - List(Id("resp"), Id("err")).Op(":=").Id("a").Dot("agent").Dot(methodName).Call(Id("p")), - If(Id("err").Op("!=").Nil()).Block(Return(Nil(), Id("toReqErr").Call(Id("err")))), - Return(Id("resp"), Nil()), + List(Id("exp"), Id("ok")).Op(":=").Id("a").Dot("agent").Assert(Id("AgentExperimental")), + If(Op("!").Id("ok")).Block( + Return(Nil(), Id("NewMethodNotFound").Call(Id("method"))), + ), ) + if isNullResponse(schema.Defs[respName]) { + caseBody = append(caseBody, + If( + List(Id("err")).Op(":=").Id("exp").Dot(methodName).Call(Id("p")), + Id("err").Op("!=").Nil(), + ).Block( + Return(Nil(), Id("toReqErr").Call(Id("err"))), + ), + Return(Nil(), Nil()), + ) + } else { + caseBody = append(caseBody, + List(Id("resp"), Id("err")).Op(":=").Id("exp").Dot(methodName).Call(Id("p")), + If(Id("err").Op("!=").Nil()).Block(Return(Nil(), Id("toReqErr").Call(Id("err")))), + Return(Id("resp"), Nil()), + ) + } + } else { + if isNullResponse(schema.Defs[respName]) { + caseBody = append(caseBody, + If( + List(Id("err")).Op(":=").Id("a").Dot("agent").Dot(methodName).Call(Id("p")), + Id("err").Op("!=").Nil(), + ).Block( + Return(Nil(), Id("toReqErr").Call(Id("err"))), + ), + Return(Nil(), Nil()), + ) + } else { + caseBody = append(caseBody, + List(Id("resp"), Id("err")).Op(":=").Id("a").Dot("agent").Dot(methodName).Call(Id("p")), + If(Id("err").Op("!=").Nil()).Block(Return(Nil(), Id("toReqErr").Call(Id("err")))), + Return(Id("resp"), Nil()), + ) + } } } if len(caseBody) > 0 { @@ -1216,19 +1352,44 @@ func writeDispatchJen(outDir string, schema *Schema, meta *Meta) error { ), ) methodName := strings.TrimSuffix(mi.Req, "Request") - if isNullResponse(schema.Defs[respName]) { + // Optional/experimental undocumented methods: require ClientExperimental + if isDocsIgnoredMethod(schema, mi) { + // Perform type assertion first, then branch body = append(body, - If(List(Id("err")).Op(":=").Id("c").Dot("client").Dot(methodName).Call(Id("p")), Id("err").Op("!=").Nil()).Block( - Return(Nil(), Id("toReqErr").Call(Id("err"))), + List(Id("t"), Id("ok")).Op(":=").Id("c").Dot("client").Assert(Id("ClientExperimental")), + If(Op("!").Id("ok")).Block( + Return(Nil(), Id("NewMethodNotFound").Call(Id("method"))), ), - Return(Nil(), Nil()), ) + if isNullResponse(schema.Defs[respName]) { + body = append(body, + If(List(Id("err")).Op(":=").Id("t").Dot(methodName).Call(Id("p")), Id("err").Op("!=").Nil()).Block( + Return(Nil(), Id("toReqErr").Call(Id("err"))), + ), + Return(Nil(), Nil()), + ) + } else { + body = append(body, + List(Id("resp"), Id("err")).Op(":=").Id("t").Dot(methodName).Call(Id("p")), + If(Id("err").Op("!=").Nil()).Block(Return(Nil(), Id("toReqErr").Call(Id("err")))), + Return(Id("resp"), Nil()), + ) + } } else { - body = append(body, - List(Id("resp"), Id("err")).Op(":=").Id("c").Dot("client").Dot(methodName).Call(Id("p")), - If(Id("err").Op("!=").Nil()).Block(Return(Nil(), Id("toReqErr").Call(Id("err")))), - Return(Id("resp"), Nil()), - ) + if isNullResponse(schema.Defs[respName]) { + body = append(body, + If(List(Id("err")).Op(":=").Id("c").Dot("client").Dot(methodName).Call(Id("p")), Id("err").Op("!=").Nil()).Block( + Return(Nil(), Id("toReqErr").Call(Id("err"))), + ), + Return(Nil(), Nil()), + ) + } else { + body = append(body, + List(Id("resp"), Id("err")).Op(":=").Id("c").Dot("client").Dot(methodName).Call(Id("p")), + If(Id("err").Op("!=").Nil()).Block(Return(Nil(), Id("toReqErr").Call(Id("err")))), + Return(Id("resp"), Nil()), + ) + } } } if len(body) > 0 { diff --git a/go/example/agent/main.go b/go/example/agent/main.go index 779148a..ab7df17 100644 --- a/go/example/agent/main.go +++ b/go/example/agent/main.go @@ -21,7 +21,10 @@ type exampleAgent struct { sessions map[string]*agentSession } -var _ acp.Agent = (*exampleAgent)(nil) +var ( + _ acp.Agent = (*exampleAgent)(nil) + _ acp.AgentLoader = (*exampleAgent)(nil) +) func newExampleAgent() *exampleAgent { return &exampleAgent{sessions: make(map[string]*agentSession)} diff --git a/go/example/client/main.go b/go/example/client/main.go index 594d34a..e204488 100644 --- a/go/example/client/main.go +++ b/go/example/client/main.go @@ -13,6 +13,11 @@ import ( type exampleClient struct{} +var ( + _ acp.Client = (*exampleClient)(nil) + _ acp.ClientExperimental = (*exampleClient)(nil) +) + func (e *exampleClient) RequestPermission(params acp.RequestPermissionRequest) (acp.RequestPermissionResponse, error) { fmt.Printf("\n🔐 Permission requested: %s\n", params.ToolCall.Title) fmt.Println("\nOptions:") @@ -168,8 +173,11 @@ func main() { // Initialize initResp, err := conn.Initialize(acp.InitializeRequest{ - ProtocolVersion: acp.ProtocolVersionNumber, - ClientCapabilities: acp.ClientCapabilities{Fs: acp.FileSystemCapability{ReadTextFile: true, WriteTextFile: true}}, + ProtocolVersion: acp.ProtocolVersionNumber, + ClientCapabilities: acp.ClientCapabilities{ + Fs: acp.FileSystemCapability{ReadTextFile: true, WriteTextFile: true}, + Terminal: true, + }, }) if err != nil { fmt.Fprintf(os.Stderr, "initialize error: %v\n", err) diff --git a/go/example/gemini/main.go b/go/example/gemini/main.go index b431801..6d47f8a 100644 --- a/go/example/gemini/main.go +++ b/go/example/gemini/main.go @@ -19,7 +19,10 @@ type replClient struct { autoApprove bool } -var _ acp.Client = (*replClient)(nil) +var ( + _ acp.Client = (*replClient)(nil) + _ acp.ClientExperimental = (*replClient)(nil) +) func (c *replClient) RequestPermission(params acp.RequestPermissionRequest) (acp.RequestPermissionResponse, error) { if c.autoApprove { @@ -64,7 +67,7 @@ func (c *replClient) SessionUpdate(params acp.SessionNotification) error { case u.AgentMessageChunk != nil: content := u.AgentMessageChunk.Content if content.Type == "text" && content.Text != nil { - fmt.Printf("[agent] \n%s\n", content.Text.Text) + fmt.Printf("%s", content.Text.Text) } else { fmt.Printf("[agent] %s\n", content.Type) } @@ -194,8 +197,11 @@ func main() { // Initialize initResp, err := conn.Initialize(acp.InitializeRequest{ - ProtocolVersion: acp.ProtocolVersionNumber, - ClientCapabilities: acp.ClientCapabilities{Fs: acp.FileSystemCapability{ReadTextFile: true, WriteTextFile: true}}, + ProtocolVersion: acp.ProtocolVersionNumber, + ClientCapabilities: acp.ClientCapabilities{ + Fs: acp.FileSystemCapability{ReadTextFile: true, WriteTextFile: true}, + Terminal: true, + }, }) if err != nil { fmt.Fprintf(os.Stderr, "initialize error: %v\n", err) @@ -216,7 +222,7 @@ func main() { fmt.Println("Type a message and press Enter to send. Commands: :cancel, :exit") scanner := bufio.NewScanner(os.Stdin) for { - fmt.Print("> ") + fmt.Print("\n> ") if !scanner.Scan() { break } diff --git a/go/types.go b/go/types.go index 4f86231..7f7ca7a 100644 --- a/go/types.go +++ b/go/types.go @@ -1086,16 +1086,25 @@ type Agent interface { Authenticate(params AuthenticateRequest) error Initialize(params InitializeRequest) (InitializeResponse, error) Cancel(params CancelNotification) error - LoadSession(params LoadSessionRequest) error NewSession(params NewSessionRequest) (NewSessionResponse, error) Prompt(params PromptRequest) (PromptResponse, error) } + +// AgentLoader defines optional support for loading sessions. Implement and advertise the capability to enable 'session/load'. +type AgentLoader interface { + LoadSession(params LoadSessionRequest) error +} type Client interface { ReadTextFile(params ReadTextFileRequest) (ReadTextFileResponse, error) WriteTextFile(params WriteTextFileRequest) error RequestPermission(params RequestPermissionRequest) (RequestPermissionResponse, error) SessionUpdate(params SessionNotification) error +} + +// ClientExperimental defines undocumented/experimental methods (x-docs-ignore), such as terminal support. Implement and advertise the related capability to enable them. +type ClientExperimental interface { CreateTerminal(params CreateTerminalRequest) (CreateTerminalResponse, error) TerminalOutput(params TerminalOutputRequest) (TerminalOutputResponse, error) ReleaseTerminal(params ReleaseTerminalRequest) error + WaitForTerminalExit(params WaitForTerminalExitRequest) (WaitForTerminalExitResponse, error) } From 25d4360a924c93c3a3da6a658955b4bb87c78a2a Mon Sep 17 00:00:00 2001 From: Thomas Kosiewski Date: Sat, 30 Aug 2025 14:43:47 +0200 Subject: [PATCH 04/22] feat: split terminal methods into separate ClientTerminal interface Change-Id: I6f5b08f6e93cd8fc5b904ab91014c206647a4aca Signed-off-by: Thomas Kosiewski --- go/client_gen.go | 8 ++++---- go/cmd/generate/main.go | 21 +++++++++++++++++---- go/example/client/main.go | 4 ++-- go/example/gemini/main.go | 4 ++-- go/types.go | 4 ++-- 5 files changed, 27 insertions(+), 14 deletions(-) diff --git a/go/client_gen.go b/go/client_gen.go index 09f592c..9ea61aa 100644 --- a/go/client_gen.go +++ b/go/client_gen.go @@ -64,7 +64,7 @@ func (c *ClientSideConnection) handle(method string, params json.RawMessage) (an if err := p.Validate(); err != nil { return nil, NewInvalidParams(map[string]any{"error": err.Error()}) } - t, ok := c.client.(ClientExperimental) + t, ok := c.client.(ClientTerminal) if !ok { return nil, NewMethodNotFound(method) } @@ -81,7 +81,7 @@ func (c *ClientSideConnection) handle(method string, params json.RawMessage) (an if err := p.Validate(); err != nil { return nil, NewInvalidParams(map[string]any{"error": err.Error()}) } - t, ok := c.client.(ClientExperimental) + t, ok := c.client.(ClientTerminal) if !ok { return nil, NewMethodNotFound(method) } @@ -98,7 +98,7 @@ func (c *ClientSideConnection) handle(method string, params json.RawMessage) (an if err := p.Validate(); err != nil { return nil, NewInvalidParams(map[string]any{"error": err.Error()}) } - t, ok := c.client.(ClientExperimental) + t, ok := c.client.(ClientTerminal) if !ok { return nil, NewMethodNotFound(method) } @@ -114,7 +114,7 @@ func (c *ClientSideConnection) handle(method string, params json.RawMessage) (an if err := p.Validate(); err != nil { return nil, NewInvalidParams(map[string]any{"error": err.Error()}) } - t, ok := c.client.(ClientExperimental) + t, ok := c.client.(ClientTerminal) if !ok { return nil, NewMethodNotFound(method) } diff --git a/go/cmd/generate/main.go b/go/cmd/generate/main.go index 720801c..744bac6 100644 --- a/go/cmd/generate/main.go +++ b/go/cmd/generate/main.go @@ -338,6 +338,7 @@ func writeTypesJen(outDir string, schema *Schema, meta *Meta) error { // Client clientStable := []Code{} clientExperimental := []Code{} + clientTerminal := []Code{} cmKeys := make([]string, 0, len(meta.ClientMethods)) for k := range meta.ClientMethods { cmKeys = append(cmKeys, k) @@ -351,7 +352,11 @@ func writeTypesJen(outDir string, schema *Schema, meta *Meta) error { } target := &clientStable if isDocsIgnored(mi) { - target = &clientExperimental + if strings.HasPrefix(wire, "terminal/") { + target = &clientTerminal + } else { + target = &clientExperimental + } } if mi.Notif != "" { name := dispatchMethodNameForNotification(k, mi.Notif) @@ -367,8 +372,12 @@ func writeTypesJen(outDir string, schema *Schema, meta *Meta) error { } } f.Type().Id("Client").Interface(clientStable...) + if len(clientTerminal) > 0 { + f.Comment("ClientTerminal defines terminal-related experimental methods (x-docs-ignore). Implement and advertise 'terminal: true' to enable 'terminal/*'.") + f.Type().Id("ClientTerminal").Interface(clientTerminal...) + } if len(clientExperimental) > 0 { - f.Comment("ClientExperimental defines undocumented/experimental methods (x-docs-ignore), such as terminal support. Implement and advertise the related capability to enable them.") + f.Comment("ClientExperimental defines undocumented/experimental methods (x-docs-ignore) other than terminals. These may change or be removed without notice.") f.Type().Id("ClientExperimental").Interface(clientExperimental...) } } @@ -1352,11 +1361,15 @@ func writeDispatchJen(outDir string, schema *Schema, meta *Meta) error { ), ) methodName := strings.TrimSuffix(mi.Req, "Request") - // Optional/experimental undocumented methods: require ClientExperimental + // Optional/experimental undocumented methods: require ClientTerminal for terminal/*, ClientExperimental otherwise if isDocsIgnoredMethod(schema, mi) { + clientIface := "ClientExperimental" + if strings.HasPrefix(wire, "terminal/") { + clientIface = "ClientTerminal" + } // Perform type assertion first, then branch body = append(body, - List(Id("t"), Id("ok")).Op(":=").Id("c").Dot("client").Assert(Id("ClientExperimental")), + List(Id("t"), Id("ok")).Op(":=").Id("c").Dot("client").Assert(Id(clientIface)), If(Op("!").Id("ok")).Block( Return(Nil(), Id("NewMethodNotFound").Call(Id("method"))), ), diff --git a/go/example/client/main.go b/go/example/client/main.go index e204488..996684d 100644 --- a/go/example/client/main.go +++ b/go/example/client/main.go @@ -14,8 +14,8 @@ import ( type exampleClient struct{} var ( - _ acp.Client = (*exampleClient)(nil) - _ acp.ClientExperimental = (*exampleClient)(nil) + _ acp.Client = (*exampleClient)(nil) + _ acp.ClientTerminal = (*exampleClient)(nil) ) func (e *exampleClient) RequestPermission(params acp.RequestPermissionRequest) (acp.RequestPermissionResponse, error) { diff --git a/go/example/gemini/main.go b/go/example/gemini/main.go index 6d47f8a..b3667ad 100644 --- a/go/example/gemini/main.go +++ b/go/example/gemini/main.go @@ -20,8 +20,8 @@ type replClient struct { } var ( - _ acp.Client = (*replClient)(nil) - _ acp.ClientExperimental = (*replClient)(nil) + _ acp.Client = (*replClient)(nil) + _ acp.ClientTerminal = (*replClient)(nil) ) func (c *replClient) RequestPermission(params acp.RequestPermissionRequest) (acp.RequestPermissionResponse, error) { diff --git a/go/types.go b/go/types.go index 7f7ca7a..549ecc4 100644 --- a/go/types.go +++ b/go/types.go @@ -1101,8 +1101,8 @@ type Client interface { SessionUpdate(params SessionNotification) error } -// ClientExperimental defines undocumented/experimental methods (x-docs-ignore), such as terminal support. Implement and advertise the related capability to enable them. -type ClientExperimental interface { +// ClientTerminal defines terminal-related experimental methods (x-docs-ignore). Implement and advertise 'terminal: true' to enable 'terminal/*'. +type ClientTerminal interface { CreateTerminal(params CreateTerminalRequest) (CreateTerminalResponse, error) TerminalOutput(params TerminalOutputRequest) (TerminalOutputResponse, error) ReleaseTerminal(params ReleaseTerminalRequest) error From 53f0c7f1c2bf5421996fbcad27eb19325610af29 Mon Sep 17 00:00:00 2001 From: Thomas Kosiewski Date: Sun, 31 Aug 2025 15:19:30 +0200 Subject: [PATCH 05/22] feat: improve error handling with JSON-formatted RequestError output and add Claude Code example Change-Id: I7ee9a6217c3c33ef27e5143acf8e2f9e17ded3dc Signed-off-by: Thomas Kosiewski --- go/errors.go | 30 +++- go/example/claude-code/main.go | 265 +++++++++++++++++++++++++++++++++ go/example/client/main.go | 32 +++- go/example/gemini/main.go | 32 +++- package-lock.json | 4 +- 5 files changed, 346 insertions(+), 17 deletions(-) create mode 100644 go/example/claude-code/main.go diff --git a/go/errors.go b/go/errors.go index e9263a0..c937b91 100644 --- a/go/errors.go +++ b/go/errors.go @@ -1,5 +1,10 @@ package acp +import ( + "encoding/json" + "fmt" +) + // RequestError represents a JSON-RPC error response. type RequestError struct { Code int `json:"code"` @@ -7,7 +12,30 @@ type RequestError struct { Data any `json:"data,omitempty"` } -func (e *RequestError) Error() string { return e.Message } +func (e *RequestError) Error() string { + // Prefer a structured, JSON-style string so callers get details by default + // similar to the TypeScript client. + // Example: {"code":-32603,"message":"Internal error","data":{"details":"..."}} + if e == nil { + return "" + } + // Try to pretty-print compact JSON for stability in logs. + type view struct { + Code int `json:"code"` + Message string `json:"message"` + Data any `json:"data,omitempty"` + } + v := view{Code: e.Code, Message: e.Message, Data: e.Data} + b, err := json.Marshal(v) + if err == nil { + return string(b) + } + // Fallback if marshal fails. + if e.Data != nil { + return fmt.Sprintf("code %d: %s (data: %v)", e.Code, e.Message, e.Data) + } + return fmt.Sprintf("code %d: %s", e.Code, e.Message) +} func NewParseError(data any) *RequestError { return &RequestError{Code: -32700, Message: "Parse error", Data: data} diff --git a/go/example/claude-code/main.go b/go/example/claude-code/main.go new file mode 100644 index 0000000..3f28947 --- /dev/null +++ b/go/example/claude-code/main.go @@ -0,0 +1,265 @@ +package main + +import ( + "bufio" + "encoding/json" + "flag" + "fmt" + "os" + "os/exec" + "path/filepath" + "strings" + + acp "github.com/zed-industries/agent-client-protocol/go" +) + +// ClaudeCodeREPL demonstrates connecting to the Claude Code CLI running in ACP mode +// and providing a simple REPL to send prompts and print streamed updates. + +type replClient struct { + autoApprove bool +} + +var _ acp.Client = (*replClient)(nil) + +func (c *replClient) RequestPermission(params acp.RequestPermissionRequest) (acp.RequestPermissionResponse, error) { + if c.autoApprove { + // Prefer an allow option if present; otherwise choose the first option. + for _, o := range params.Options { + if o.Kind == acp.PermissionOptionKindAllowOnce || o.Kind == acp.PermissionOptionKindAllowAlways { + return acp.RequestPermissionResponse{Outcome: acp.RequestPermissionOutcome{Selected: &acp.RequestPermissionOutcomeSelected{OptionId: o.OptionId}}}, nil + } + } + if len(params.Options) > 0 { + return acp.RequestPermissionResponse{Outcome: acp.RequestPermissionOutcome{Selected: &acp.RequestPermissionOutcomeSelected{OptionId: params.Options[0].OptionId}}}, nil + } + return acp.RequestPermissionResponse{Outcome: acp.RequestPermissionOutcome{Cancelled: &acp.RequestPermissionOutcomeCancelled{}}}, nil + } + + fmt.Printf("\n🔐 Permission requested: %s\n", params.ToolCall.Title) + fmt.Println("\nOptions:") + for i, opt := range params.Options { + fmt.Printf(" %d. %s (%s)\n", i+1, opt.Name, opt.Kind) + } + reader := bufio.NewReader(os.Stdin) + for { + fmt.Printf("\nChoose an option: ") + line, _ := reader.ReadString('\n') + line = strings.TrimSpace(line) + if line == "" { + continue + } + idx := -1 + fmt.Sscanf(line, "%d", &idx) + idx = idx - 1 + if idx >= 0 && idx < len(params.Options) { + return acp.RequestPermissionResponse{Outcome: acp.RequestPermissionOutcome{Selected: &acp.RequestPermissionOutcomeSelected{OptionId: params.Options[idx].OptionId}}}, nil + } + fmt.Println("Invalid option. Please try again.") + } +} + +func (c *replClient) SessionUpdate(params acp.SessionNotification) error { + u := params.Update + switch { + case u.AgentMessageChunk != nil: + content := u.AgentMessageChunk.Content + if content.Type == "text" && content.Text != nil { + fmt.Printf("[agent] \n%s\n", content.Text.Text) + } else { + fmt.Printf("[agent] %s\n", content.Type) + } + case u.ToolCall != nil: + fmt.Printf("\n🔧 %s (%s)\n", u.ToolCall.Title, u.ToolCall.Status) + case u.ToolCallUpdate != nil: + fmt.Printf("\n🔧 Tool call `%s` updated: %v\n\n", u.ToolCallUpdate.ToolCallId, u.ToolCallUpdate.Status) + case u.Plan != nil: + fmt.Println("[plan update]") + case u.AgentThoughtChunk != nil: + thought := u.AgentThoughtChunk.Content + if thought.Type == "text" && thought.Text != nil { + fmt.Printf("[agent_thought_chunk] \n%s\n", thought.Text.Text) + } else { + fmt.Println("[agent_thought_chunk]", "(", thought.Type, ")") + } + case u.UserMessageChunk != nil: + fmt.Println("[user_message_chunk]") + } + return nil +} + +func (c *replClient) WriteTextFile(params acp.WriteTextFileRequest) error { + if !filepath.IsAbs(params.Path) { + return fmt.Errorf("path must be absolute: %s", params.Path) + } + dir := filepath.Dir(params.Path) + if dir != "" { + if err := os.MkdirAll(dir, 0o755); err != nil { + return fmt.Errorf("mkdir %s: %w", dir, err) + } + } + if err := os.WriteFile(params.Path, []byte(params.Content), 0o644); err != nil { + return fmt.Errorf("write %s: %w", params.Path, err) + } + fmt.Printf("[Client] Wrote %d bytes to %s\n", len(params.Content), params.Path) + return nil +} + +func (c *replClient) ReadTextFile(params acp.ReadTextFileRequest) (acp.ReadTextFileResponse, error) { + if !filepath.IsAbs(params.Path) { + return acp.ReadTextFileResponse{}, fmt.Errorf("path must be absolute: %s", params.Path) + } + b, err := os.ReadFile(params.Path) + if err != nil { + return acp.ReadTextFileResponse{}, fmt.Errorf("read %s: %w", params.Path, err) + } + content := string(b) + if params.Line > 0 || params.Limit > 0 { + lines := strings.Split(content, "\n") + start := 0 + if params.Line > 0 { + start = min(max(params.Line-1, 0), len(lines)) + } + end := len(lines) + if params.Limit > 0 { + if start+params.Limit < end { + end = start + params.Limit + } + } + content = strings.Join(lines[start:end], "\n") + } + fmt.Printf("[Client] ReadTextFile: %s (%d bytes)\n", params.Path, len(content)) + return acp.ReadTextFileResponse{Content: content}, nil +} + +// Optional/UNSTABLE terminal methods: implement as no-ops for example +func (c *replClient) CreateTerminal(params acp.CreateTerminalRequest) (acp.CreateTerminalResponse, error) { + fmt.Printf("[Client] CreateTerminal: %v\n", params) + return acp.CreateTerminalResponse{TerminalId: "term-1"}, nil +} + +func (c *replClient) TerminalOutput(params acp.TerminalOutputRequest) (acp.TerminalOutputResponse, error) { + fmt.Printf("[Client] TerminalOutput: %v\n", params) + return acp.TerminalOutputResponse{Output: "", Truncated: false}, nil +} + +func (c *replClient) ReleaseTerminal(params acp.ReleaseTerminalRequest) error { + fmt.Printf("[Client] ReleaseTerminal: %v\n", params) + return nil +} + +func (c *replClient) WaitForTerminalExit(params acp.WaitForTerminalExitRequest) (acp.WaitForTerminalExitResponse, error) { + fmt.Printf("[Client] WaitForTerminalExit: %v\n", params) + return acp.WaitForTerminalExitResponse{}, nil +} + +func main() { + yolo := flag.Bool("yolo", false, "Auto-approve permission prompts") + flag.Parse() + + // Invoke Claude Code via npx + cmd := exec.Command("npx", "-y", "@zed-industries/claude-code-acp") + cmd.Stderr = os.Stderr + stdin, err := cmd.StdinPipe() + if err != nil { + fmt.Fprintf(os.Stderr, "stdin pipe error: %v\n", err) + os.Exit(1) + } + stdout, err := cmd.StdoutPipe() + if err != nil { + fmt.Fprintf(os.Stderr, "stdout pipe error: %v\n", err) + os.Exit(1) + } + + if err := cmd.Start(); err != nil { + fmt.Fprintf(os.Stderr, "failed to start Claude Code: %v\n", err) + os.Exit(1) + } + + client := &replClient{autoApprove: *yolo} + conn := acp.NewClientSideConnection(client, stdin, stdout) + + // Initialize + initResp, err := conn.Initialize(acp.InitializeRequest{ + ProtocolVersion: acp.ProtocolVersionNumber, + ClientCapabilities: acp.ClientCapabilities{Fs: acp.FileSystemCapability{ReadTextFile: true, WriteTextFile: true}}, + }) + if err != nil { + if re, ok := err.(*acp.RequestError); ok { + if b, mErr := json.MarshalIndent(re, "", " "); mErr == nil { + fmt.Fprintf(os.Stderr, "[Client] Error: %s\n", string(b)) + } else { + fmt.Fprintf(os.Stderr, "initialize error (%d): %s\n", re.Code, re.Message) + } + } else { + fmt.Fprintf(os.Stderr, "initialize error: %v\n", err) + } + _ = cmd.Process.Kill() + os.Exit(1) + } + fmt.Printf("✅ Connected to Claude Code (protocol v%v)\n", initResp.ProtocolVersion) + + // New session + newSess, err := conn.NewSession(acp.NewSessionRequest{Cwd: mustCwd(), McpServers: []acp.McpServer{}}) + if err != nil { + if re, ok := err.(*acp.RequestError); ok { + if b, mErr := json.MarshalIndent(re, "", " "); mErr == nil { + fmt.Fprintf(os.Stderr, "[Client] Error: %s\n", string(b)) + } else { + fmt.Fprintf(os.Stderr, "newSession error (%d): %s\n", re.Code, re.Message) + } + } else { + fmt.Fprintf(os.Stderr, "newSession error: %v\n", err) + } + _ = cmd.Process.Kill() + os.Exit(1) + } + fmt.Printf("📝 Created session: %s\n", newSess.SessionId) + + fmt.Println("Type a message and press Enter to send. Commands: :cancel, :exit") + scanner := bufio.NewScanner(os.Stdin) + for { + fmt.Print("> ") + if !scanner.Scan() { + break + } + line := strings.TrimSpace(scanner.Text()) + if line == "" { + continue + } + switch line { + case ":exit", ":quit": + _ = cmd.Process.Kill() + return + case ":cancel": + _ = conn.Cancel(acp.CancelNotification(newSess)) + continue + } + // Send prompt and wait for completion while streaming updates are printed via SessionUpdate + if _, err := conn.Prompt(acp.PromptRequest{ + SessionId: newSess.SessionId, + Prompt: []acp.ContentBlock{{Type: "text", Text: &acp.TextContent{Text: line}}}, + }); err != nil { + // If it's a JSON-RPC RequestError, surface more detail for troubleshooting + if re, ok := err.(*acp.RequestError); ok { + if b, mErr := json.MarshalIndent(re, "", " "); mErr == nil { + fmt.Fprintf(os.Stderr, "[Client] Error: %s\n", string(b)) + } else { + fmt.Fprintf(os.Stderr, "prompt error (%d): %s\n", re.Code, re.Message) + } + } else { + fmt.Fprintf(os.Stderr, "prompt error: %v\n", err) + } + } + } + + _ = cmd.Process.Kill() +} + +func mustCwd() string { + wd, err := os.Getwd() + if err != nil { + return "." + } + return wd +} diff --git a/go/example/client/main.go b/go/example/client/main.go index 996684d..b8cb317 100644 --- a/go/example/client/main.go +++ b/go/example/client/main.go @@ -2,6 +2,7 @@ package main import ( "bufio" + "encoding/json" "fmt" "os" "os/exec" @@ -14,8 +15,8 @@ import ( type exampleClient struct{} var ( - _ acp.Client = (*exampleClient)(nil) - _ acp.ClientTerminal = (*exampleClient)(nil) + _ acp.Client = (*exampleClient)(nil) + _ acp.ClientTerminal = (*exampleClient)(nil) ) func (e *exampleClient) RequestPermission(params acp.RequestPermissionRequest) (acp.RequestPermissionResponse, error) { @@ -180,7 +181,15 @@ func main() { }, }) if err != nil { - fmt.Fprintf(os.Stderr, "initialize error: %v\n", err) + if re, ok := err.(*acp.RequestError); ok { + if b, mErr := json.MarshalIndent(re, "", " "); mErr == nil { + fmt.Fprintf(os.Stderr, "[Client] Error: %s\n", string(b)) + } else { + fmt.Fprintf(os.Stderr, "initialize error (%d): %s\n", re.Code, re.Message) + } + } else { + fmt.Fprintf(os.Stderr, "initialize error: %v\n", err) + } _ = cmd.Process.Kill() os.Exit(1) } @@ -189,7 +198,15 @@ func main() { // New session newSess, err := conn.NewSession(acp.NewSessionRequest{Cwd: mustCwd(), McpServers: []acp.McpServer{}}) if err != nil { - fmt.Fprintf(os.Stderr, "newSession error: %v\n", err) + if re, ok := err.(*acp.RequestError); ok { + if b, mErr := json.MarshalIndent(re, "", " "); mErr == nil { + fmt.Fprintf(os.Stderr, "[Client] Error: %s\n", string(b)) + } else { + fmt.Fprintf(os.Stderr, "newSession error (%d): %s\n", re.Code, re.Message) + } + } else { + fmt.Fprintf(os.Stderr, "newSession error: %v\n", err) + } _ = cmd.Process.Kill() os.Exit(1) } @@ -206,9 +223,10 @@ func main() { }}, }); err != nil { if re, ok := err.(*acp.RequestError); ok { - fmt.Fprintf(os.Stderr, "prompt error (%d): %s\n", re.Code, re.Message) - if re.Data != nil { - fmt.Fprintf(os.Stderr, "details: %v\n", re.Data) + if b, mErr := json.MarshalIndent(re, "", " "); mErr == nil { + fmt.Fprintf(os.Stderr, "[Client] Error: %s\n", string(b)) + } else { + fmt.Fprintf(os.Stderr, "prompt error (%d): %s\n", re.Code, re.Message) } } else { fmt.Fprintf(os.Stderr, "prompt error: %v\n", err) diff --git a/go/example/gemini/main.go b/go/example/gemini/main.go index b3667ad..982b04d 100644 --- a/go/example/gemini/main.go +++ b/go/example/gemini/main.go @@ -2,6 +2,7 @@ package main import ( "bufio" + "encoding/json" "flag" "fmt" "os" @@ -20,8 +21,8 @@ type replClient struct { } var ( - _ acp.Client = (*replClient)(nil) - _ acp.ClientTerminal = (*replClient)(nil) + _ acp.Client = (*replClient)(nil) + _ acp.ClientTerminal = (*replClient)(nil) ) func (c *replClient) RequestPermission(params acp.RequestPermissionRequest) (acp.RequestPermissionResponse, error) { @@ -204,7 +205,15 @@ func main() { }, }) if err != nil { - fmt.Fprintf(os.Stderr, "initialize error: %v\n", err) + if re, ok := err.(*acp.RequestError); ok { + if b, mErr := json.MarshalIndent(re, "", " "); mErr == nil { + fmt.Fprintf(os.Stderr, "[Client] Error: %s\n", string(b)) + } else { + fmt.Fprintf(os.Stderr, "initialize error (%d): %s\n", re.Code, re.Message) + } + } else { + fmt.Fprintf(os.Stderr, "initialize error: %v\n", err) + } _ = cmd.Process.Kill() os.Exit(1) } @@ -213,7 +222,15 @@ func main() { // New session newSess, err := conn.NewSession(acp.NewSessionRequest{Cwd: mustCwd(), McpServers: []acp.McpServer{}}) if err != nil { - fmt.Fprintf(os.Stderr, "newSession error: %v\n", err) + if re, ok := err.(*acp.RequestError); ok { + if b, mErr := json.MarshalIndent(re, "", " "); mErr == nil { + fmt.Fprintf(os.Stderr, "[Client] Error: %s\n", string(b)) + } else { + fmt.Fprintf(os.Stderr, "newSession error (%d): %s\n", re.Code, re.Message) + } + } else { + fmt.Fprintf(os.Stderr, "newSession error: %v\n", err) + } _ = cmd.Process.Kill() os.Exit(1) } @@ -245,9 +262,10 @@ func main() { }); err != nil { // If it's a JSON-RPC RequestError, surface more detail for troubleshooting if re, ok := err.(*acp.RequestError); ok { - fmt.Fprintf(os.Stderr, "prompt error (%d): %s\n", re.Code, re.Message) - if re.Data != nil { - fmt.Fprintf(os.Stderr, "details: %v\n", re.Data) + if b, mErr := json.MarshalIndent(re, "", " "); mErr == nil { + fmt.Fprintf(os.Stderr, "[Client] Error: %s\n", string(b)) + } else { + fmt.Fprintf(os.Stderr, "prompt error (%d): %s\n", re.Code, re.Message) } } else { fmt.Fprintf(os.Stderr, "prompt error: %v\n", err) diff --git a/package-lock.json b/package-lock.json index ea931f9..bc7d648 100644 --- a/package-lock.json +++ b/package-lock.json @@ -1,12 +1,12 @@ { "name": "@zed-industries/agent-client-protocol", - "version": "0.1.3-alpha.0", + "version": "0.2.0-alpha.0", "lockfileVersion": 3, "requires": true, "packages": { "": { "name": "@zed-industries/agent-client-protocol", - "version": "0.1.3-alpha.0", + "version": "0.2.0-alpha.0", "license": "Apache-2.0", "dependencies": { "zod": "^3.0.0" From 56a191b44d3c14eedd0331d137d7d4073b7bf44e Mon Sep 17 00:00:00 2001 From: Thomas Kosiewski Date: Sun, 31 Aug 2025 15:58:17 +0200 Subject: [PATCH 06/22] feat: refactor code generation into modular emit package with dispatch handlers Change-Id: Ie1670abd513fab42dedaa2c9e55362840f2f9985 Signed-off-by: Thomas Kosiewski --- go/cmd/generate/internal/emit/constants.go | 63 + go/cmd/generate/internal/emit/dispatch.go | 221 +++ .../internal/emit/dispatch_helpers.go | 78 + go/cmd/generate/internal/emit/jenwrap.go | 36 + go/cmd/generate/internal/emit/types.go | 756 ++++++++ go/cmd/generate/internal/ir/ir.go | 248 +++ go/cmd/generate/internal/load/load.go | 64 + go/cmd/generate/internal/util/util.go | 76 + go/cmd/generate/main.go | 1549 +---------------- go/{constants.go => constants_gen.go} | 0 go/{types.go => types_gen.go} | 0 11 files changed, 1566 insertions(+), 1525 deletions(-) create mode 100644 go/cmd/generate/internal/emit/constants.go create mode 100644 go/cmd/generate/internal/emit/dispatch.go create mode 100644 go/cmd/generate/internal/emit/dispatch_helpers.go create mode 100644 go/cmd/generate/internal/emit/jenwrap.go create mode 100644 go/cmd/generate/internal/emit/types.go create mode 100644 go/cmd/generate/internal/ir/ir.go create mode 100644 go/cmd/generate/internal/load/load.go create mode 100644 go/cmd/generate/internal/util/util.go rename go/{constants.go => constants_gen.go} (100%) rename go/{types.go => types_gen.go} (100%) diff --git a/go/cmd/generate/internal/emit/constants.go b/go/cmd/generate/internal/emit/constants.go new file mode 100644 index 0000000..7074772 --- /dev/null +++ b/go/cmd/generate/internal/emit/constants.go @@ -0,0 +1,63 @@ +package emit + +import ( + "bytes" + "os" + "path/filepath" + "sort" + "strings" + + "github.com/zed-industries/agent-client-protocol/go/cmd/generate/internal/load" + "github.com/zed-industries/agent-client-protocol/go/cmd/generate/internal/util" +) + +// WriteConstantsJen writes the version and method constants to constants_gen.go. +func WriteConstantsJen(outDir string, meta *load.Meta) error { + f := NewFile("acp") + f.HeaderComment("Code generated by acp-go-generator; DO NOT EDIT.") + f.Comment("ProtocolVersionNumber is the ACP protocol version supported by this SDK.") + f.Const().Id("ProtocolVersionNumber").Op("=").Lit(meta.Version) + + // Agent methods + amKeys := make([]string, 0, len(meta.AgentMethods)) + for k := range meta.AgentMethods { + amKeys = append(amKeys, k) + } + sort.Strings(amKeys) + var agentDefs []Code + for _, k := range amKeys { + wire := meta.AgentMethods[k] + agentDefs = append(agentDefs, Id("AgentMethod"+toExportedConst(k)).Op("=").Lit(wire)) + } + f.Comment("Agent method names") + f.Const().Defs(agentDefs...) + + // Client methods + cmKeys := make([]string, 0, len(meta.ClientMethods)) + for k := range meta.ClientMethods { + cmKeys = append(cmKeys, k) + } + sort.Strings(cmKeys) + var clientDefs []Code + for _, k := range cmKeys { + wire := meta.ClientMethods[k] + clientDefs = append(clientDefs, Id("ClientMethod"+toExportedConst(k)).Op("=").Lit(wire)) + } + f.Comment("Client method names") + f.Const().Defs(clientDefs...) + + var buf bytes.Buffer + if err := f.Render(&buf); err != nil { + return err + } + return os.WriteFile(filepath.Join(outDir, "constants_gen.go"), buf.Bytes(), 0o644) +} + +// Helpers kept private to this package (copy from original main) +func toExportedConst(s string) string { + parts := strings.Split(s, "_") + for i := range parts { + parts[i] = util.TitleWord(parts[i]) + } + return strings.Join(parts, "") +} diff --git a/go/cmd/generate/internal/emit/dispatch.go b/go/cmd/generate/internal/emit/dispatch.go new file mode 100644 index 0000000..3a144cd --- /dev/null +++ b/go/cmd/generate/internal/emit/dispatch.go @@ -0,0 +1,221 @@ +package emit + +import ( + "bytes" + "os" + "path/filepath" + "sort" + "strings" + + "github.com/zed-industries/agent-client-protocol/go/cmd/generate/internal/ir" + "github.com/zed-industries/agent-client-protocol/go/cmd/generate/internal/load" +) + +// WriteDispatchJen emits agent_gen.go and client_gen.go with handlers and wrappers. +func WriteDispatchJen(outDir string, schema *load.Schema, meta *load.Meta) error { + groups := ir.BuildMethodGroups(schema, meta) + + // Agent handler + outbound wrappers + fAgent := NewFile("acp") + fAgent.HeaderComment("Code generated by acp-go-generator; DO NOT EDIT.") + + amKeys := make([]string, 0, len(meta.AgentMethods)) + for k := range meta.AgentMethods { + amKeys = append(amKeys, k) + } + sort.Strings(amKeys) + switchCases := []Code{} + for _, k := range amKeys { + wire := meta.AgentMethods[k] + mi := groups["agent|"+wire] + if mi == nil { + continue + } + caseBody := []Code{} + if mi.Notif != "" { + caseBody = append(caseBody, jUnmarshalValidate(mi.Notif)...) + callName := ir.DispatchMethodNameForNotification(k, mi.Notif) + caseBody = append(caseBody, jCallNotification("a.agent", callName)...) + } else if mi.Req != "" { + respName := strings.TrimSuffix(mi.Req, "Request") + "Response" + caseBody = append(caseBody, jUnmarshalValidate(mi.Req)...) + methodName := strings.TrimSuffix(mi.Req, "Request") + pre, recv := jAgentAssert(mi.Binding) + if pre != nil { + caseBody = append(caseBody, pre...) + } + if ir.IsNullResponse(schema.Defs[respName]) { + caseBody = append(caseBody, jCallRequestNoResp(recv, methodName)...) + } else { + caseBody = append(caseBody, jCallRequestWithResp(recv, methodName, respName)...) + } + } + if len(caseBody) > 0 { + switchCases = append(switchCases, Case(Id("AgentMethod"+toExportedConst(k))).Block(caseBody...)) + } + } + switchCases = append(switchCases, Default().Block(Return(Nil(), Id("NewMethodNotFound").Call(Id("method"))))) + fAgent.Func().Params(Id("a").Op("*").Id("AgentSideConnection")).Id("handle").Params( + Id("method").String(), Id("params").Qual("encoding/json", "RawMessage")). + Params(Any(), Op("*").Id("RequestError")). + Block(Switch(Id("method")).Block(switchCases...)) + + // Agent outbound wrappers (agent -> client) + agentConst := map[string]string{} + for k, v := range meta.AgentMethods { + agentConst[v] = "AgentMethod" + toExportedConst(k) + } + clientConst := map[string]string{} + for k, v := range meta.ClientMethods { + clientConst[v] = "ClientMethod" + toExportedConst(k) + } + + cmKeys2 := make([]string, 0, len(meta.ClientMethods)) + for k := range meta.ClientMethods { + cmKeys2 = append(cmKeys2, k) + } + sort.Strings(cmKeys2) + for _, k := range cmKeys2 { + wire := meta.ClientMethods[k] + mi := groups["client|"+wire] + if mi == nil { + continue + } + constName := clientConst[mi.Method] + if constName == "" { + continue + } + if mi.Notif != "" { + name := strings.TrimSuffix(mi.Notif, "Notification") + switch mi.Method { + case "session/update": + name = "SessionUpdate" + case "session/cancel": + name = "Cancel" + } + fAgent.Func().Params(Id("c").Op("*").Id("AgentSideConnection")).Id(name).Params(Id("params").Id(mi.Notif)).Error(). + Block(Return(Id("c").Dot("conn").Dot("SendNotification").Call(Id(constName), Id("params")))) + } else if mi.Req != "" { + respName := strings.TrimSuffix(mi.Req, "Request") + "Response" + if ir.IsNullResponse(schema.Defs[respName]) { + fAgent.Func().Params(Id("c").Op("*").Id("AgentSideConnection")).Id(strings.TrimSuffix(mi.Req, "Request")). + Params(Id("params").Id(mi.Req)).Error(). + Block(Return(Id("c").Dot("conn").Dot("SendRequestNoResult").Call(Id(constName), Id("params")))) + } else { + fAgent.Func().Params(Id("c").Op("*").Id("AgentSideConnection")).Id(strings.TrimSuffix(mi.Req, "Request")). + Params(Id("params").Id(mi.Req)).Params(Id(respName), Error()). + Block( + List(Id("resp"), Id("err")).Op(":=").Id("SendRequest").Types(Id(respName)).Call(Id("c").Dot("conn"), Id(constName), Id("params")), + Return(Id("resp"), Id("err")), + ) + } + } + } + var bufA bytes.Buffer + if err := fAgent.Render(&bufA); err != nil { + return err + } + if err := os.WriteFile(filepath.Join(outDir, "agent_gen.go"), bufA.Bytes(), 0o644); err != nil { + return err + } + + // Client handler + outbound wrappers + fClient := NewFile("acp") + fClient.HeaderComment("Code generated by acp-go-generator; DO NOT EDIT.") + cmKeys := make([]string, 0, len(meta.ClientMethods)) + for k := range meta.ClientMethods { + cmKeys = append(cmKeys, k) + } + sort.Strings(cmKeys) + cCases := []Code{} + for _, k := range cmKeys { + wire := meta.ClientMethods[k] + mi := groups["client|"+wire] + if mi == nil { + continue + } + body := []Code{} + if mi.Notif != "" { + body = append(body, jUnmarshalValidate(mi.Notif)...) + pre, recv := jClientAssert(mi.Binding) + if pre != nil { + body = append(body, pre...) + } + callName := ir.DispatchMethodNameForNotification(k, mi.Notif) + body = append(body, jCallNotification(recv, callName)...) + } else if mi.Req != "" { + respName := strings.TrimSuffix(mi.Req, "Request") + "Response" + body = append(body, jUnmarshalValidate(mi.Req)...) + methodName := strings.TrimSuffix(mi.Req, "Request") + pre, recv := jClientAssert(mi.Binding) + if pre != nil { + body = append(body, pre...) + } + if ir.IsNullResponse(schema.Defs[respName]) { + body = append(body, jCallRequestNoResp(recv, methodName)...) + } else { + body = append(body, jCallRequestWithResp(recv, methodName, respName)...) + } + } + if len(body) > 0 { + cCases = append(cCases, Case(Id("ClientMethod"+toExportedConst(k))).Block(body...)) + } + } + cCases = append(cCases, Default().Block(Return(Nil(), Id("NewMethodNotFound").Call(Id("method"))))) + fClient.Func().Params(Id("c").Op("*").Id("ClientSideConnection")).Id("handle").Params( + Id("method").String(), Id("params").Qual("encoding/json", "RawMessage")). + Params(Any(), Op("*").Id("RequestError")). + Block(Switch(Id("method")).Block(cCases...)) + + // Client outbound wrappers (client -> agent) + amKeys2 := make([]string, 0, len(meta.AgentMethods)) + for k := range meta.AgentMethods { + amKeys2 = append(amKeys2, k) + } + sort.Strings(amKeys2) + for _, k := range amKeys2 { + wire := meta.AgentMethods[k] + mi := groups["agent|"+wire] + if mi == nil { + continue + } + constName := agentConst[mi.Method] + if constName == "" { + continue + } + if mi.Notif != "" { + name := strings.TrimSuffix(mi.Notif, "Notification") + switch mi.Method { + case "session/update": + name = "SessionUpdate" + case "session/cancel": + name = "Cancel" + } + fClient.Func().Params(Id("c").Op("*").Id("ClientSideConnection")).Id(name).Params(Id("params").Id(mi.Notif)).Error(). + Block(Return(Id("c").Dot("conn").Dot("SendNotification").Call(Id(constName), Id("params")))) + } else if mi.Req != "" { + respName := strings.TrimSuffix(mi.Req, "Request") + "Response" + if ir.IsNullResponse(schema.Defs[respName]) { + fClient.Func().Params(Id("c").Op("*").Id("ClientSideConnection")).Id(strings.TrimSuffix(mi.Req, "Request")). + Params(Id("params").Id(mi.Req)).Error(). + Block(Return(Id("c").Dot("conn").Dot("SendRequestNoResult").Call(Id(constName), Id("params")))) + } else { + fClient.Func().Params(Id("c").Op("*").Id("ClientSideConnection")).Id(strings.TrimSuffix(mi.Req, "Request")). + Params(Id("params").Id(mi.Req)).Params(Id(respName), Error()). + Block( + List(Id("resp"), Id("err")).Op(":=").Id("SendRequest").Types(Id(respName)).Call(Id("c").Dot("conn"), Id(constName), Id("params")), + Return(Id("resp"), Id("err")), + ) + } + } + } + var bufC bytes.Buffer + if err := fClient.Render(&bufC); err != nil { + return err + } + if err := os.WriteFile(filepath.Join(outDir, "client_gen.go"), bufC.Bytes(), 0o644); err != nil { + return err + } + + return nil +} diff --git a/go/cmd/generate/internal/emit/dispatch_helpers.go b/go/cmd/generate/internal/emit/dispatch_helpers.go new file mode 100644 index 0000000..98c7ed3 --- /dev/null +++ b/go/cmd/generate/internal/emit/dispatch_helpers.go @@ -0,0 +1,78 @@ +package emit + +import ( + "github.com/zed-industries/agent-client-protocol/go/cmd/generate/internal/ir" +) + +// invInvalid: return invalid params with compact json-like message +func jInvInvalid() Code { + return Return(Nil(), Id("NewInvalidParams").Call(Map(String()).Any().Values(Dict{Lit("error"): Id("err").Dot("Error").Call()}))) +} + +// retToReqErr: wrap error to JSON-RPC request error +func jRetToReqErr() Code { return Return(Nil(), Id("toReqErr").Call(Id("err"))) } + +// jUnmarshalValidate emits var p T; json.Unmarshal; p.Validate +func jUnmarshalValidate(typeName string) []Code { + return []Code{ + Var().Id("p").Id(typeName), + If(List(Id("err")).Op(":=").Qual("encoding/json", "Unmarshal").Call(Id("params"), Op("&").Id("p")), Id("err").Op("!=").Nil()). + Block(jInvInvalid()), + If(List(Id("err")).Op(":=").Id("p").Dot("Validate").Call(), Id("err").Op("!=").Nil()). + Block(jInvInvalid()), + } +} + +// jAgentAssert returns prelude for interface assertions and the receiver name. +func jAgentAssert(binding ir.MethodBinding) ([]Code, string) { + switch binding { + case ir.BindAgentLoader: + return []Code{ + List(Id("loader"), Id("ok")).Op(":=").Id("a").Dot("agent").Assert(Id("AgentLoader")), + If(Op("!").Id("ok")).Block(Return(Nil(), Id("NewMethodNotFound").Call(Id("method")))), + }, "loader" + case ir.BindAgentExperimental: + return []Code{ + List(Id("exp"), Id("ok")).Op(":=").Id("a").Dot("agent").Assert(Id("AgentExperimental")), + If(Op("!").Id("ok")).Block(Return(Nil(), Id("NewMethodNotFound").Call(Id("method")))), + }, "exp" + default: + return nil, "a.agent" + } +} + +// jClientAssert returns prelude for interface assertions and the receiver name. +func jClientAssert(binding ir.MethodBinding) ([]Code, string) { + switch binding { + case ir.BindClientTerminal: + return []Code{ + List(Id("t"), Id("ok")).Op(":=").Id("c").Dot("client").Assert(Id("ClientTerminal")), + If(Op("!").Id("ok")).Block(Return(Nil(), Id("NewMethodNotFound").Call(Id("method")))), + }, "t" + default: + return nil, "c.client" + } +} + +// Request call emitters for handlers +func jCallRequestNoResp(recv, methodName string) []Code { + return []Code{ + If(List(Id("err")).Op(":=").Id(recv).Dot(methodName).Call(Id("p")), Id("err").Op("!=").Nil()).Block(jRetToReqErr()), + Return(Nil(), Nil()), + } +} + +func jCallRequestWithResp(recv, methodName, respType string) []Code { + return []Code{ + List(Id("resp"), Id("err")).Op(":=").Id(recv).Dot(methodName).Call(Id("p")), + If(Id("err").Op("!=").Nil()).Block(jRetToReqErr()), + Return(Id("resp"), Nil()), + } +} + +func jCallNotification(recv, methodName string) []Code { + return []Code{ + If(List(Id("err")).Op(":=").Id(recv).Dot(methodName).Call(Id("p")), Id("err").Op("!=").Nil()).Block(jRetToReqErr()), + Return(Nil(), Nil()), + } +} diff --git a/go/cmd/generate/internal/emit/jenwrap.go b/go/cmd/generate/internal/emit/jenwrap.go new file mode 100644 index 0000000..23e7123 --- /dev/null +++ b/go/cmd/generate/internal/emit/jenwrap.go @@ -0,0 +1,36 @@ +package emit + +import jen "github.com/dave/jennifer/jen" + +// Local aliases to avoid dot-importing jennifer while keeping concise calls. +type ( + Code = jen.Code + Dict = jen.Dict + Group = jen.Group + File = jen.File +) + +var ( + NewFile = jen.NewFile + Id = jen.Id + Lit = jen.Lit + Return = jen.Return + Nil = jen.Nil + String = jen.String + Int = jen.Int + Float64 = jen.Float64 + Bool = jen.Bool + Any = jen.Any + Map = jen.Map + Index = jen.Index + Qual = jen.Qual + Error = jen.Error + Case = jen.Case + Default = jen.Default + Switch = jen.Switch + Var = jen.Var + If = jen.If + List = jen.List + Op = jen.Op + Comment = jen.Comment +) diff --git a/go/cmd/generate/internal/emit/types.go b/go/cmd/generate/internal/emit/types.go new file mode 100644 index 0000000..d5bdc53 --- /dev/null +++ b/go/cmd/generate/internal/emit/types.go @@ -0,0 +1,756 @@ +package emit + +import ( + "bytes" + "fmt" + "os" + "path/filepath" + "sort" + "strings" + + "github.com/zed-industries/agent-client-protocol/go/cmd/generate/internal/ir" + "github.com/zed-industries/agent-client-protocol/go/cmd/generate/internal/load" + "github.com/zed-industries/agent-client-protocol/go/cmd/generate/internal/util" +) + +// WriteTypesJen emits go/types_gen.go with all types and the Agent/Client interfaces. +func WriteTypesJen(outDir string, schema *load.Schema, meta *load.Meta) error { + f := NewFile("acp") + f.HeaderComment("Code generated by acp-go-generator; DO NOT EDIT.") + + // Deterministic order + keys := make([]string, 0, len(schema.Defs)) + for k := range schema.Defs { + keys = append(keys, k) + } + sort.Strings(keys) + + for _, name := range keys { + def := schema.Defs[name] + if def == nil { + continue + } + + if def.Description != "" { + f.Comment(util.SanitizeComment(def.Description)) + } + + switch { + case len(def.Enum) > 0: + f.Type().Id(name).String() + defs := []Code{} + for _, v := range def.Enum { + s := fmt.Sprint(v) + defs = append(defs, Id(util.ToEnumConst(name, s)).Id(name).Op("=").Lit(s)) + } + if len(defs) > 0 { + f.Const().Defs(defs...) + } + f.Line() + case isStringConstUnion(def): + f.Type().Id(name).String() + defs := []Code{} + for _, v := range def.OneOf { + if v != nil && v.Const != nil { + s := fmt.Sprint(v.Const) + defs = append(defs, Id(util.ToEnumConst(name, s)).Id(name).Op("=").Lit(s)) + } + } + if len(defs) > 0 { + f.Const().Defs(defs...) + } + f.Line() + case name == "ContentBlock": + emitContentBlockJen(f) + case name == "ToolCallContent": + emitToolCallContentJen(f) + case name == "EmbeddedResourceResource": + emitEmbeddedResourceResourceJen(f) + case name == "RequestPermissionOutcome": + emitRequestPermissionOutcomeJen(f) + case name == "SessionUpdate": + emitSessionUpdateJen(f) + case ir.PrimaryType(def) == "object" && len(def.Properties) > 0: + st := []Code{} + req := map[string]struct{}{} + for _, r := range def.Required { + req[r] = struct{}{} + } + pkeys := make([]string, 0, len(def.Properties)) + for pk := range def.Properties { + pkeys = append(pkeys, pk) + } + sort.Strings(pkeys) + for _, pk := range pkeys { + prop := def.Properties[pk] + field := util.ToExportedField(pk) + if prop.Description != "" { + st = append(st, Comment(util.SanitizeComment(prop.Description))) + } + tag := pk + if _, ok := req[pk]; !ok { + tag = pk + ",omitempty" + } + st = append(st, Id(field).Add(jenTypeForOptional(prop)).Tag(map[string]string{"json": tag})) + } + f.Type().Id(name).Struct(st...) + f.Line() + case ir.PrimaryType(def) == "string" || ir.PrimaryType(def) == "integer" || ir.PrimaryType(def) == "number" || ir.PrimaryType(def) == "boolean": + f.Type().Id(name).Add(primitiveJenType(ir.PrimaryType(def))) + f.Line() + default: + f.Comment(fmt.Sprintf("%s is a union or complex schema; represented generically.", name)) + f.Type().Id(name).Any() + f.Line() + } + + // validators for selected types + if strings.HasSuffix(name, "Request") || strings.HasSuffix(name, "Response") || strings.HasSuffix(name, "Notification") || name == "ContentBlock" || name == "ToolCallContent" || name == "SessionUpdate" || name == "ToolCallUpdate" { + emitValidateJen(f, name, def) + } + } + + // Append Agent & Client interfaces from method groups + groups := ir.BuildMethodGroups(schema, meta) + + // Agent + agentMethods := []Code{} + agentLoaderMethods := []Code{} + agentExperimentalMethods := []Code{} + for _, k := range ir.SortedKeys(meta.AgentMethods) { + wire := meta.AgentMethods[k] + mi := groups["agent|"+wire] + if mi == nil { + continue + } + target := &agentMethods + switch mi.Binding { + case ir.BindAgentLoader: + target = &agentLoaderMethods + case ir.BindAgentExperimental: + target = &agentExperimentalMethods + } + if mi.Notif != "" { + name := ir.DispatchMethodNameForNotification(k, mi.Notif) + *target = append(*target, Id(name).Params(Id("params").Id(mi.Notif)).Error()) + } else if mi.Req != "" { + respName := strings.TrimSuffix(mi.Req, "Request") + "Response" + methodName := strings.TrimSuffix(mi.Req, "Request") + if ir.IsNullResponse(schema.Defs[respName]) { + *target = append(*target, Id(methodName).Params(Id("params").Id(mi.Req)).Error()) + } else { + *target = append(*target, Id(methodName).Params(Id("params").Id(mi.Req)).Params(Id(respName), Error())) + } + } + } + f.Type().Id("Agent").Interface(agentMethods...) + if len(agentLoaderMethods) > 0 { + f.Comment("AgentLoader defines optional support for loading sessions. Implement and advertise the capability to enable 'session/load'.") + f.Type().Id("AgentLoader").Interface(agentLoaderMethods...) + } + if len(agentExperimentalMethods) > 0 { + f.Comment("AgentExperimental defines undocumented/experimental methods (x-docs-ignore). These may change or be removed without notice.") + f.Type().Id("AgentExperimental").Interface(agentExperimentalMethods...) + } + + // Client + clientStable := []Code{} + clientExperimental := []Code{} + clientTerminal := []Code{} + for _, k := range ir.SortedKeys(meta.ClientMethods) { + wire := meta.ClientMethods[k] + mi := groups["client|"+wire] + if mi == nil { + continue + } + target := &clientStable + switch mi.Binding { + case ir.BindClientExperimental: + target = &clientExperimental + case ir.BindClientTerminal: + target = &clientTerminal + } + if mi.Notif != "" { + name := ir.DispatchMethodNameForNotification(k, mi.Notif) + *target = append(*target, Id(name).Params(Id("params").Id(mi.Notif)).Error()) + } else if mi.Req != "" { + respName := strings.TrimSuffix(mi.Req, "Request") + "Response" + methodName := strings.TrimSuffix(mi.Req, "Request") + if ir.IsNullResponse(schema.Defs[respName]) { + *target = append(*target, Id(methodName).Params(Id("params").Id(mi.Req)).Error()) + } else { + *target = append(*target, Id(methodName).Params(Id("params").Id(mi.Req)).Params(Id(respName), Error())) + } + } + } + f.Type().Id("Client").Interface(clientStable...) + if len(clientTerminal) > 0 { + f.Comment("ClientTerminal defines terminal-related experimental methods (x-docs-ignore). Implement and advertise 'terminal: true' to enable 'terminal/*'.") + f.Type().Id("ClientTerminal").Interface(clientTerminal...) + } + if len(clientExperimental) > 0 { + f.Comment("ClientExperimental defines undocumented/experimental methods (x-docs-ignore) other than terminals. These may change or be removed without notice.") + f.Type().Id("ClientExperimental").Interface(clientExperimental...) + } + + var buf bytes.Buffer + if err := f.Render(&buf); err != nil { + return err + } + return os.WriteFile(filepath.Join(outDir, "types_gen.go"), buf.Bytes(), 0o644) +} + +func isStringConstUnion(def *load.Definition) bool { + if def == nil || len(def.OneOf) == 0 { + return false + } + for _, v := range def.OneOf { + if v == nil || v.Const == nil { + return false + } + if _, ok := v.Const.(string); !ok { + return false + } + } + return true +} + +// emitValidateJen generates validators for selected types (logic unchanged). +func emitValidateJen(f *File, name string, def *load.Definition) { + switch name { + case "ContentBlock": + f.Func().Params(Id("c").Op("*").Id("ContentBlock")).Id("Validate").Params().Params(Error()).Block( + Switch(Id("c").Dot("Type")).Block( + Case(Lit("text")).Block(If(Id("c").Dot("Text").Op("==").Nil()).Block(Return(Qual("fmt", "Errorf").Call(Lit("contentblock.text missing"))))), + Case(Lit("image")).Block(If(Id("c").Dot("Image").Op("==").Nil()).Block(Return(Qual("fmt", "Errorf").Call(Lit("contentblock.image missing"))))), + Case(Lit("audio")).Block(If(Id("c").Dot("Audio").Op("==").Nil()).Block(Return(Qual("fmt", "Errorf").Call(Lit("contentblock.audio missing"))))), + Case(Lit("resource_link")).Block(If(Id("c").Dot("ResourceLink").Op("==").Nil()).Block(Return(Qual("fmt", "Errorf").Call(Lit("contentblock.resource_link missing"))))), + Case(Lit("resource")).Block(If(Id("c").Dot("Resource").Op("==").Nil()).Block(Return(Qual("fmt", "Errorf").Call(Lit("contentblock.resource missing"))))), + ), + Return(Nil()), + ) + return + case "ToolCallContent": + f.Func().Params(Id("t").Op("*").Id("ToolCallContent")).Id("Validate").Params().Params(Error()).Block( + Switch(Id("t").Dot("Type")).Block( + Case(Lit("content")).Block(If(Id("t").Dot("Content").Op("==").Nil()).Block(Return(Qual("fmt", "Errorf").Call(Lit("toolcallcontent.content missing"))))), + Case(Lit("diff")).Block(If(Id("t").Dot("Diff").Op("==").Nil()).Block(Return(Qual("fmt", "Errorf").Call(Lit("toolcallcontent.diff missing"))))), + Case(Lit("terminal")).Block(If(Id("t").Dot("Terminal").Op("==").Nil()).Block(Return(Qual("fmt", "Errorf").Call(Lit("toolcallcontent.terminal missing"))))), + ), + Return(Nil()), + ) + return + case "SessionUpdate": + f.Func().Params(Id("s").Op("*").Id("SessionUpdate")).Id("Validate").Params().Params(Error()).Block( + Var().Id("count").Int(), + If(Id("s").Dot("UserMessageChunk").Op("!=").Nil()).Block(Id("count").Op("++")), + If(Id("s").Dot("AgentMessageChunk").Op("!=").Nil()).Block(Id("count").Op("++")), + If(Id("s").Dot("AgentThoughtChunk").Op("!=").Nil()).Block(Id("count").Op("++")), + If(Id("s").Dot("ToolCall").Op("!=").Nil()).Block(Id("count").Op("++")), + If(Id("s").Dot("ToolCallUpdate").Op("!=").Nil()).Block(Id("count").Op("++")), + If(Id("s").Dot("Plan").Op("!=").Nil()).Block(Id("count").Op("++")), + If(Id("count").Op("!=").Lit(1)).Block(Return(Qual("fmt", "Errorf").Call(Lit("sessionupdate must have exactly one variant set")))), + Return(Nil()), + ) + return + case "ToolCallUpdate": + f.Func().Params(Id("t").Op("*").Id("ToolCallUpdate")).Id("Validate").Params().Params(Error()).Block( + If(Id("t").Dot("ToolCallId").Op("==").Lit("")).Block(Return(Qual("fmt", "Errorf").Call(Lit("toolCallId is required")))), + Return(Nil()), + ) + return + } + if def != nil && ir.PrimaryType(def) == "object" { + if !(strings.HasSuffix(name, "Request") || strings.HasSuffix(name, "Response") || strings.HasSuffix(name, "Notification")) { + return + } + f.Func().Params(Id("v").Op("*").Id(name)).Id("Validate").Params().Params(Error()).BlockFunc(func(g *Group) { + pkeys := make([]string, 0, len(def.Properties)) + for pk := range def.Properties { + pkeys = append(pkeys, pk) + } + sort.Strings(pkeys) + for _, propName := range pkeys { + pDef := def.Properties[propName] + required := false + for _, r := range def.Required { + if r == propName { + required = true + break + } + } + field := util.ToExportedField(propName) + if required { + switch ir.PrimaryType(pDef) { + case "string": + g.If(Id("v").Dot(field).Op("==").Lit("")).Block(Return(Qual("fmt", "Errorf").Call(Lit(propName + " is required")))) + case "array": + g.If(Id("v").Dot(field).Op("==").Nil()).Block(Return(Qual("fmt", "Errorf").Call(Lit(propName + " is required")))) + } + } + } + g.Return(Nil()) + }) + } +} + +// Type mapping helpers (unchanged behavior vs original) +func primitiveJenType(t string) Code { + switch t { + case "string": + return String() + case "integer": + return Int() + case "number": + return Float64() + case "boolean": + return Bool() + default: + return Any() + } +} + +func jenTypeFor(d *load.Definition) Code { + if d == nil { + return Any() + } + if d.Ref != "" { + if strings.HasPrefix(d.Ref, "#/$defs/") { + return Id(d.Ref[len("#/$defs/"):]) + } + return Any() + } + if len(d.Enum) > 0 { + return String() + } + switch ir.PrimaryType(d) { + case "string": + return String() + case "integer": + return Int() + case "number": + return Float64() + case "boolean": + return Bool() + case "array": + return Index().Add(jenTypeFor(d.Items)) + case "object": + if len(d.Properties) == 0 { + return Map(String()).Any() + } + return Map(String()).Any() + default: + if len(d.AnyOf) > 0 || len(d.OneOf) > 0 { + return Any() + } + return Any() + } +} + +// jenTypeForOptional maps unions that include null to pointer types where applicable. +func jenTypeForOptional(d *load.Definition) Code { + if d == nil { + return Any() + } + list := d.AnyOf + if len(list) == 0 { + list = d.OneOf + } + if len(list) == 2 { + var nonNull *load.Definition + for _, e := range list { + if e == nil { + continue + } + if s, ok := e.Type.(string); ok && s == "null" { + continue + } + if e.Const != nil { + nn := *e + nn.Type = "string" + nonNull = &nn + } else { + nonNull = e + } + } + if nonNull != nil { + if nonNull.Ref != "" && strings.HasPrefix(nonNull.Ref, "#/$defs/") { + return Op("*").Id(nonNull.Ref[len("#/$defs/"):]) + } + switch ir.PrimaryType(nonNull) { + case "string": + return Op("*").String() + case "integer": + return Op("*").Int() + case "number": + return Op("*").Float64() + case "boolean": + return Op("*").Bool() + } + } + } + return jenTypeFor(d) +} + +// Specialized emitters copied from original (unchanged behavior). +func emitContentBlockJen(f *File) { + f.Type().Id("ResourceLinkContent").Struct( + Id("Annotations").Any().Tag(map[string]string{"json": "annotations,omitempty"}), + Id("Description").Op("*").String().Tag(map[string]string{"json": "description,omitempty"}), + Id("MimeType").Op("*").String().Tag(map[string]string{"json": "mimeType,omitempty"}), + Id("Name").String().Tag(map[string]string{"json": "name"}), + Id("Size").Op("*").Int64().Tag(map[string]string{"json": "size,omitempty"}), + Id("Title").Op("*").String().Tag(map[string]string{"json": "title,omitempty"}), + Id("Uri").String().Tag(map[string]string{"json": "uri"}), + ) + f.Line() + f.Type().Id("ContentBlock").Struct( + Id("Type").String().Tag(map[string]string{"json": "type"}), + Id("Text").Op("*").Id("TextContent").Tag(map[string]string{"json": "-"}), + Id("Image").Op("*").Id("ImageContent").Tag(map[string]string{"json": "-"}), + Id("Audio").Op("*").Id("AudioContent").Tag(map[string]string{"json": "-"}), + Id("ResourceLink").Op("*").Id("ResourceLinkContent").Tag(map[string]string{"json": "-"}), + Id("Resource").Op("*").Id("EmbeddedResource").Tag(map[string]string{"json": "-"}), + ) + f.Line() + f.Func().Params(Id("c").Op("*").Id("ContentBlock")).Id("UnmarshalJSON").Params(Id("b").Index().Byte()).Error().Block( + Var().Id("probe").Struct(Id("Type").String().Tag(map[string]string{"json": "type"})), + If(List(Id("err")).Op(":=").Qual("encoding/json", "Unmarshal").Call(Id("b"), Op("&").Id("probe")), Id("err").Op("!=").Nil()).Block(Return(Id("err"))), + Id("c").Dot("Type").Op("=").Id("probe").Dot("Type"), + Switch(Id("probe").Dot("Type")).Block( + Case(Lit("text")).Block( + Var().Id("v").Id("TextContent"), + If(List(Id("err")).Op(":=").Qual("encoding/json", "Unmarshal").Call(Id("b"), Op("&").Id("v")), Id("err").Op("!=").Nil()).Block(Return(Id("err"))), + Id("c").Dot("Text").Op("=").Op("&").Id("v"), + ), + Case(Lit("image")).Block( + Var().Id("v").Id("ImageContent"), + If(List(Id("err")).Op(":=").Qual("encoding/json", "Unmarshal").Call(Id("b"), Op("&").Id("v")), Id("err").Op("!=").Nil()).Block(Return(Id("err"))), + Id("c").Dot("Image").Op("=").Op("&").Id("v"), + ), + Case(Lit("audio")).Block( + Var().Id("v").Id("AudioContent"), + If(List(Id("err")).Op(":=").Qual("encoding/json", "Unmarshal").Call(Id("b"), Op("&").Id("v")), Id("err").Op("!=").Nil()).Block(Return(Id("err"))), + Id("c").Dot("Audio").Op("=").Op("&").Id("v"), + ), + Case(Lit("resource_link")).Block( + Var().Id("v").Id("ResourceLinkContent"), + If(List(Id("err")).Op(":=").Qual("encoding/json", "Unmarshal").Call(Id("b"), Op("&").Id("v")), Id("err").Op("!=").Nil()).Block(Return(Id("err"))), + Id("c").Dot("ResourceLink").Op("=").Op("&").Id("v"), + ), + Case(Lit("resource")).Block( + Var().Id("v").Id("EmbeddedResource"), + If(List(Id("err")).Op(":=").Qual("encoding/json", "Unmarshal").Call(Id("b"), Op("&").Id("v")), Id("err").Op("!=").Nil()).Block(Return(Id("err"))), + Id("c").Dot("Resource").Op("=").Op("&").Id("v"), + ), + ), + Return(Nil()), + ) + f.Func().Params(Id("c").Id("ContentBlock")).Id("MarshalJSON").Params().Params(Index().Byte(), Error()).Block( + Switch(Id("c").Dot("Type")).Block( + Case(Lit("text")).Block( + If(Id("c").Dot("Text").Op("!=").Nil()).Block( + Return(Qual("encoding/json", "Marshal").Call(Map(String()).Any().Values(Dict{ + Lit("type"): Lit("text"), + Lit("text"): Id("c").Dot("Text").Dot("Text"), + }))), + ), + ), + Case(Lit("image")).Block( + If(Id("c").Dot("Image").Op("!=").Nil()).Block( + Return(Qual("encoding/json", "Marshal").Call(Map(String()).Any().Values(Dict{ + Lit("type"): Lit("image"), + Lit("data"): Id("c").Dot("Image").Dot("Data"), + Lit("mimeType"): Id("c").Dot("Image").Dot("MimeType"), + Lit("uri"): Id("c").Dot("Image").Dot("Uri"), + }))), + ), + ), + Case(Lit("audio")).Block( + If(Id("c").Dot("Audio").Op("!=").Nil()).Block( + Return(Qual("encoding/json", "Marshal").Call(Map(String()).Any().Values(Dict{ + Lit("type"): Lit("audio"), + Lit("data"): Id("c").Dot("Audio").Dot("Data"), + Lit("mimeType"): Id("c").Dot("Audio").Dot("MimeType"), + }))), + ), + ), + Case(Lit("resource_link")).Block( + If(Id("c").Dot("ResourceLink").Op("!=").Nil()).Block( + Return(Qual("encoding/json", "Marshal").Call(Map(String()).Any().Values(Dict{ + Lit("type"): Lit("resource_link"), + Lit("name"): Id("c").Dot("ResourceLink").Dot("Name"), + Lit("uri"): Id("c").Dot("ResourceLink").Dot("Uri"), + Lit("description"): Id("c").Dot("ResourceLink").Dot("Description"), + Lit("mimeType"): Id("c").Dot("ResourceLink").Dot("MimeType"), + Lit("size"): Id("c").Dot("ResourceLink").Dot("Size"), + Lit("title"): Id("c").Dot("ResourceLink").Dot("Title"), + }))), + ), + ), + Case(Lit("resource")).Block( + If(Id("c").Dot("Resource").Op("!=").Nil()).Block( + Return(Qual("encoding/json", "Marshal").Call(Map(String()).Any().Values(Dict{ + Lit("type"): Lit("resource"), + Lit("resource"): Id("c").Dot("Resource").Dot("Resource"), + }))), + ), + ), + ), + Return(Index().Byte().Values(), Nil()), + ) + f.Line() +} + +func emitToolCallContentJen(f *File) { + f.Type().Id("DiffContent").Struct( + Id("NewText").String().Tag(map[string]string{"json": "newText"}), + Id("OldText").Op("*").String().Tag(map[string]string{"json": "oldText,omitempty"}), + Id("Path").String().Tag(map[string]string{"json": "path"}), + ) + f.Type().Id("TerminalRef").Struct(Id("TerminalId").String().Tag(map[string]string{"json": "terminalId"})) + f.Line() + f.Type().Id("ToolCallContent").Struct( + Id("Type").String().Tag(map[string]string{"json": "type"}), + Id("Content").Op("*").Id("ContentBlock").Tag(map[string]string{"json": "-"}), + Id("Diff").Op("*").Id("DiffContent").Tag(map[string]string{"json": "-"}), + Id("Terminal").Op("*").Id("TerminalRef").Tag(map[string]string{"json": "-"}), + ) + f.Line() + f.Func().Params(Id("t").Op("*").Id("ToolCallContent")).Id("UnmarshalJSON").Params(Id("b").Index().Byte()).Error().Block( + Var().Id("probe").Struct(Id("Type").String().Tag(map[string]string{"json": "type"})), + If(List(Id("err")).Op(":=").Qual("encoding/json", "Unmarshal").Call(Id("b"), Op("&").Id("probe")), Id("err").Op("!=").Nil()).Block(Return(Id("err"))), + Id("t").Dot("Type").Op("=").Id("probe").Dot("Type"), + Switch(Id("probe").Dot("Type")).Block( + Case(Lit("content")).Block( + Var().Id("v").Struct( + Id("Type").String().Tag(map[string]string{"json": "type"}), + Id("Content").Id("ContentBlock").Tag(map[string]string{"json": "content"}), + ), + If(List(Id("err")).Op(":=").Qual("encoding/json", "Unmarshal").Call(Id("b"), Op("&").Id("v")), Id("err").Op("!=").Nil()).Block(Return(Id("err"))), + Id("t").Dot("Content").Op("=").Op("&").Id("v").Dot("Content"), + ), + Case(Lit("diff")).Block( + Var().Id("v").Id("DiffContent"), + If(List(Id("err")).Op(":=").Qual("encoding/json", "Unmarshal").Call(Id("b"), Op("&").Id("v")), Id("err").Op("!=").Nil()).Block(Return(Id("err"))), + Id("t").Dot("Diff").Op("=").Op("&").Id("v"), + ), + Case(Lit("terminal")).Block( + Var().Id("v").Id("TerminalRef"), + If(List(Id("err")).Op(":=").Qual("encoding/json", "Unmarshal").Call(Id("b"), Op("&").Id("v")), Id("err").Op("!=").Nil()).Block(Return(Id("err"))), + Id("t").Dot("Terminal").Op("=").Op("&").Id("v"), + ), + ), + Return(Nil()), + ) + f.Line() +} + +func emitEmbeddedResourceResourceJen(f *File) { + f.Type().Id("EmbeddedResourceResource").Struct( + Id("TextResourceContents").Op("*").Id("TextResourceContents").Tag(map[string]string{"json": "-"}), + Id("BlobResourceContents").Op("*").Id("BlobResourceContents").Tag(map[string]string{"json": "-"}), + ) + f.Line() + f.Func().Params(Id("e").Op("*").Id("EmbeddedResourceResource")).Id("UnmarshalJSON").Params(Id("b").Index().Byte()).Error().Block( + Var().Id("m").Map(String()).Qual("encoding/json", "RawMessage"), + If(List(Id("err")).Op(":=").Qual("encoding/json", "Unmarshal").Call(Id("b"), Op("&").Id("m")), Id("err").Op("!=").Nil()).Block(Return(Id("err"))), + If(List(Id("_"), Id("ok")).Op(":=").Id("m").Index(Lit("text")), Id("ok")).Block( + Var().Id("v").Id("TextResourceContents"), + If(List(Id("err")).Op(":=").Qual("encoding/json", "Unmarshal").Call(Id("b"), Op("&").Id("v")), Id("err").Op("!=").Nil()).Block(Return(Id("err"))), + Id("e").Dot("TextResourceContents").Op("=").Op("&").Id("v"), + Return(Nil()), + ), + If(List(Id("_"), Id("ok2")).Op(":=").Id("m").Index(Lit("blob")), Id("ok2")).Block( + Var().Id("v").Id("BlobResourceContents"), + If(List(Id("err")).Op(":=").Qual("encoding/json", "Unmarshal").Call(Id("b"), Op("&").Id("v")), Id("err").Op("!=").Nil()).Block(Return(Id("err"))), + Id("e").Dot("BlobResourceContents").Op("=").Op("&").Id("v"), + Return(Nil()), + ), + Return(Nil()), + ) + f.Line() +} + +func emitRequestPermissionOutcomeJen(f *File) { + f.Type().Id("RequestPermissionOutcomeCancelled").Struct() + f.Type().Id("RequestPermissionOutcomeSelected").Struct( + Id("OptionId").Id("PermissionOptionId").Tag(map[string]string{"json": "optionId"}), + ) + f.Line() + f.Type().Id("RequestPermissionOutcome").Struct( + Id("Cancelled").Op("*").Id("RequestPermissionOutcomeCancelled").Tag(map[string]string{"json": "-"}), + Id("Selected").Op("*").Id("RequestPermissionOutcomeSelected").Tag(map[string]string{"json": "-"}), + ) + f.Func().Params(Id("o").Op("*").Id("RequestPermissionOutcome")).Id("UnmarshalJSON").Params(Id("b").Index().Byte()).Error().Block( + Var().Id("m").Map(String()).Qual("encoding/json", "RawMessage"), + If(List(Id("err")).Op(":=").Qual("encoding/json", "Unmarshal").Call(Id("b"), Op("&").Id("m")), Id("err").Op("!=").Nil()).Block(Return(Id("err"))), + Var().Id("outcome").String(), + If(List(Id("v"), Id("ok")).Op(":=").Id("m").Index(Lit("outcome")), Id("ok")).Block( + Qual("encoding/json", "Unmarshal").Call(Id("v"), Op("&").Id("outcome")), + ), + Switch(Id("outcome")).Block( + Case(Lit("cancelled")).Block( + Id("o").Dot("Cancelled").Op("=").Op("&").Id("RequestPermissionOutcomeCancelled").Values(), + Return(Nil()), + ), + Case(Lit("selected")).Block( + Var().Id("v2").Id("RequestPermissionOutcomeSelected"), + If(List(Id("err")).Op(":=").Qual("encoding/json", "Unmarshal").Call(Id("b"), Op("&").Id("v2")), Id("err").Op("!=").Nil()).Block(Return(Id("err"))), + Id("o").Dot("Selected").Op("=").Op("&").Id("v2"), + Return(Nil()), + ), + ), + Return(Nil()), + ) + f.Func().Params(Id("o").Id("RequestPermissionOutcome")).Id("MarshalJSON").Params().Params(Index().Byte(), Error()).Block( + If(Id("o").Dot("Cancelled").Op("!=").Nil()).Block( + Return(Qual("encoding/json", "Marshal").Call(Map(String()).Any().Values(Dict{Lit("outcome"): Lit("cancelled")}))), + ), + If(Id("o").Dot("Selected").Op("!=").Nil()).Block( + Return(Qual("encoding/json", "Marshal").Call(Map(String()).Any().Values(Dict{ + Lit("optionId"): Id("o").Dot("Selected").Dot("OptionId"), + Lit("outcome"): Lit("selected"), + }))), + ), + Return(Index().Byte().Values(), Nil()), + ) + f.Line() +} + +func emitSessionUpdateJen(f *File) { + f.Type().Id("SessionUpdateUserMessageChunk").Struct(Id("Content").Id("ContentBlock").Tag(map[string]string{"json": "content"})) + f.Type().Id("SessionUpdateAgentMessageChunk").Struct(Id("Content").Id("ContentBlock").Tag(map[string]string{"json": "content"})) + f.Type().Id("SessionUpdateAgentThoughtChunk").Struct(Id("Content").Id("ContentBlock").Tag(map[string]string{"json": "content"})) + f.Type().Id("SessionUpdateToolCall").Struct( + Id("Content").Index().Id("ToolCallContent").Tag(map[string]string{"json": "content,omitempty"}), + Id("Kind").Id("ToolKind").Tag(map[string]string{"json": "kind,omitempty"}), + Id("Locations").Index().Id("ToolCallLocation").Tag(map[string]string{"json": "locations,omitempty"}), + Id("RawInput").Any().Tag(map[string]string{"json": "rawInput,omitempty"}), + Id("RawOutput").Any().Tag(map[string]string{"json": "rawOutput,omitempty"}), + Id("Status").Id("ToolCallStatus").Tag(map[string]string{"json": "status,omitempty"}), + Id("Title").String().Tag(map[string]string{"json": "title"}), + Id("ToolCallId").Id("ToolCallId").Tag(map[string]string{"json": "toolCallId"}), + ) + f.Type().Id("SessionUpdateToolCallUpdate").Struct( + Id("Content").Index().Id("ToolCallContent").Tag(map[string]string{"json": "content,omitempty"}), + Id("Kind").Any().Tag(map[string]string{"json": "kind,omitempty"}), + Id("Locations").Index().Id("ToolCallLocation").Tag(map[string]string{"json": "locations,omitempty"}), + Id("RawInput").Any().Tag(map[string]string{"json": "rawInput,omitempty"}), + Id("RawOutput").Any().Tag(map[string]string{"json": "rawOutput,omitempty"}), + Id("Status").Any().Tag(map[string]string{"json": "status,omitempty"}), + Id("Title").Op("*").String().Tag(map[string]string{"json": "title,omitempty"}), + Id("ToolCallId").Id("ToolCallId").Tag(map[string]string{"json": "toolCallId"}), + ) + f.Type().Id("SessionUpdatePlan").Struct(Id("Entries").Index().Id("PlanEntry").Tag(map[string]string{"json": "entries"})) + f.Line() + f.Type().Id("SessionUpdate").Struct( + Id("UserMessageChunk").Op("*").Id("SessionUpdateUserMessageChunk").Tag(map[string]string{"json": "-"}), + Id("AgentMessageChunk").Op("*").Id("SessionUpdateAgentMessageChunk").Tag(map[string]string{"json": "-"}), + Id("AgentThoughtChunk").Op("*").Id("SessionUpdateAgentThoughtChunk").Tag(map[string]string{"json": "-"}), + Id("ToolCall").Op("*").Id("SessionUpdateToolCall").Tag(map[string]string{"json": "-"}), + Id("ToolCallUpdate").Op("*").Id("SessionUpdateToolCallUpdate").Tag(map[string]string{"json": "-"}), + Id("Plan").Op("*").Id("SessionUpdatePlan").Tag(map[string]string{"json": "-"}), + ) + f.Func().Params(Id("s").Op("*").Id("SessionUpdate")).Id("UnmarshalJSON").Params(Id("b").Index().Byte()).Error().Block( + Var().Id("m").Map(String()).Qual("encoding/json", "RawMessage"), + If(List(Id("err")).Op(":=").Qual("encoding/json", "Unmarshal").Call(Id("b"), Op("&").Id("m")), Id("err").Op("!=").Nil()).Block(Return(Id("err"))), + Var().Id("kind").String(), + If(List(Id("v"), Id("ok")).Op(":=").Id("m").Index(Lit("sessionUpdate")), Id("ok")).Block( + Qual("encoding/json", "Unmarshal").Call(Id("v"), Op("&").Id("kind")), + ), + Switch(Id("kind")).Block( + Case(Lit("user_message_chunk")).Block( + Var().Id("v").Id("SessionUpdateUserMessageChunk"), + If(List(Id("err")).Op(":=").Qual("encoding/json", "Unmarshal").Call(Id("b"), Op("&").Id("v")), Id("err").Op("!=").Nil()).Block(Return(Id("err"))), + Id("s").Dot("UserMessageChunk").Op("=").Op("&").Id("v"), + Return(Nil()), + ), + Case(Lit("agent_message_chunk")).Block( + Var().Id("v").Id("SessionUpdateAgentMessageChunk"), + If(List(Id("err")).Op(":=").Qual("encoding/json", "Unmarshal").Call(Id("b"), Op("&").Id("v")), Id("err").Op("!=").Nil()).Block(Return(Id("err"))), + Id("s").Dot("AgentMessageChunk").Op("=").Op("&").Id("v"), + Return(Nil()), + ), + Case(Lit("agent_thought_chunk")).Block( + Var().Id("v").Id("SessionUpdateAgentThoughtChunk"), + If(List(Id("err")).Op(":=").Qual("encoding/json", "Unmarshal").Call(Id("b"), Op("&").Id("v")), Id("err").Op("!=").Nil()).Block(Return(Id("err"))), + Id("s").Dot("AgentThoughtChunk").Op("=").Op("&").Id("v"), + Return(Nil()), + ), + Case(Lit("tool_call")).Block( + Var().Id("v").Id("SessionUpdateToolCall"), + If(List(Id("err")).Op(":=").Qual("encoding/json", "Unmarshal").Call(Id("b"), Op("&").Id("v")), Id("err").Op("!=").Nil()).Block(Return(Id("err"))), + Id("s").Dot("ToolCall").Op("=").Op("&").Id("v"), + Return(Nil()), + ), + Case(Lit("tool_call_update")).Block( + Var().Id("v").Id("SessionUpdateToolCallUpdate"), + If(List(Id("err")).Op(":=").Qual("encoding/json", "Unmarshal").Call(Id("b"), Op("&").Id("v")), Id("err").Op("!=").Nil()).Block(Return(Id("err"))), + Id("s").Dot("ToolCallUpdate").Op("=").Op("&").Id("v"), + Return(Nil()), + ), + Case(Lit("plan")).Block( + Var().Id("v").Id("SessionUpdatePlan"), + If(List(Id("err")).Op(":=").Qual("encoding/json", "Unmarshal").Call(Id("b"), Op("&").Id("v")), Id("err").Op("!=").Nil()).Block(Return(Id("err"))), + Id("s").Dot("Plan").Op("=").Op("&").Id("v"), + Return(Nil()), + ), + ), + Return(Nil()), + ) + f.Func().Params(Id("s").Id("SessionUpdate")).Id("MarshalJSON").Params().Params(Index().Byte(), Error()).Block( + If(Id("s").Dot("UserMessageChunk").Op("!=").Nil()).Block( + Return(Qual("encoding/json", "Marshal").Call(Map(String()).Any().Values(Dict{ + Lit("sessionUpdate"): Lit("user_message_chunk"), + Lit("content"): Id("s").Dot("UserMessageChunk").Dot("Content"), + }))), + ), + If(Id("s").Dot("AgentMessageChunk").Op("!=").Nil()).Block( + Return(Qual("encoding/json", "Marshal").Call(Map(String()).Any().Values(Dict{ + Lit("sessionUpdate"): Lit("agent_message_chunk"), + Lit("content"): Id("s").Dot("AgentMessageChunk").Dot("Content"), + }))), + ), + If(Id("s").Dot("AgentThoughtChunk").Op("!=").Nil()).Block( + Return(Qual("encoding/json", "Marshal").Call(Map(String()).Any().Values(Dict{ + Lit("sessionUpdate"): Lit("agent_thought_chunk"), + Lit("content"): Id("s").Dot("AgentThoughtChunk").Dot("Content"), + }))), + ), + If(Id("s").Dot("ToolCall").Op("!=").Nil()).Block( + Return(Qual("encoding/json", "Marshal").Call(Map(String()).Any().Values(Dict{ + Lit("sessionUpdate"): Lit("tool_call"), + Lit("content"): Id("s").Dot("ToolCall").Dot("Content"), + Lit("kind"): Id("s").Dot("ToolCall").Dot("Kind"), + Lit("locations"): Id("s").Dot("ToolCall").Dot("Locations"), + Lit("rawInput"): Id("s").Dot("ToolCall").Dot("RawInput"), + Lit("rawOutput"): Id("s").Dot("ToolCall").Dot("RawOutput"), + Lit("status"): Id("s").Dot("ToolCall").Dot("Status"), + Lit("title"): Id("s").Dot("ToolCall").Dot("Title"), + Lit("toolCallId"): Id("s").Dot("ToolCall").Dot("ToolCallId"), + }))), + ), + If(Id("s").Dot("ToolCallUpdate").Op("!=").Nil()).Block( + Return(Qual("encoding/json", "Marshal").Call(Map(String()).Any().Values(Dict{ + Lit("sessionUpdate"): Lit("tool_call_update"), + Lit("content"): Id("s").Dot("ToolCallUpdate").Dot("Content"), + Lit("kind"): Id("s").Dot("ToolCallUpdate").Dot("Kind"), + Lit("locations"): Id("s").Dot("ToolCallUpdate").Dot("Locations"), + Lit("rawInput"): Id("s").Dot("ToolCallUpdate").Dot("RawInput"), + Lit("rawOutput"): Id("s").Dot("ToolCallUpdate").Dot("RawOutput"), + Lit("status"): Id("s").Dot("ToolCallUpdate").Dot("Status"), + Lit("title"): Id("s").Dot("ToolCallUpdate").Dot("Title"), + Lit("toolCallId"): Id("s").Dot("ToolCallUpdate").Dot("ToolCallId"), + }))), + ), + If(Id("s").Dot("Plan").Op("!=").Nil()).Block( + Return(Qual("encoding/json", "Marshal").Call(Map(String()).Any().Values(Dict{ + Lit("sessionUpdate"): Lit("plan"), + Lit("entries"): Id("s").Dot("Plan").Dot("Entries"), + }))), + ), + Return(Index().Byte().Values(), Nil()), + ) + f.Line() +} diff --git a/go/cmd/generate/internal/ir/ir.go b/go/cmd/generate/internal/ir/ir.go new file mode 100644 index 0000000..bd62e43 --- /dev/null +++ b/go/cmd/generate/internal/ir/ir.go @@ -0,0 +1,248 @@ +package ir + +import ( + "sort" + "strings" + + "github.com/zed-industries/agent-client-protocol/go/cmd/generate/internal/load" + "github.com/zed-industries/agent-client-protocol/go/cmd/generate/internal/util" +) + +// MethodBinding describes which interface a method belongs to on each side. +type MethodBinding int + +const ( + BindUnknown MethodBinding = iota + // Agent bindings + BindAgent + BindAgentLoader + BindAgentExperimental + // Client bindings + BindClient + BindClientExperimental + BindClientTerminal +) + +// MethodInfo captures association between a wire method and its Go types and binding. +type MethodInfo struct { + Side string // "agent" or "client" + Method string // wire method, e.g., "session/new" + MethodKey string // meta key, e.g., "session_new" + Req string // Go type name of Request + Resp string // Go type name of Response + Notif string // Go type name of Notification + Binding MethodBinding + DocsIgnored bool +} + +// Groups is a map keyed by side|wire to MethodInfo. +type Groups map[string]*MethodInfo + +func key(side, method string) string { return side + "|" + method } + +// PrimaryType mirrors logic from generator: find primary type string from a Definition. +func PrimaryType(d *load.Definition) string { + if d == nil || d.Type == nil { + return "" + } + switch v := d.Type.(type) { + case string: + return v + case []any: + var first string + for _, e := range v { + if s, ok := e.(string); ok { + if first == "" { + first = s + } + if s != "null" { + return s + } + } + } + return first + default: + return "" + } +} + +// IsNullResponse returns true if the response schema is explicitly null or missing. +func IsNullResponse(def *load.Definition) bool { + if def == nil { + return true + } + if s, ok := def.Type.(string); ok && s == "null" { + return true + } + return false +} + +// BuildMethodGroups merges schema-provided links with meta fallback and returns groups. +func BuildMethodGroups(schema *load.Schema, meta *load.Meta) Groups { + groups := Groups{} + // From schema + for name, def := range schema.Defs { + if def == nil || def.XMethod == "" || def.XSide == "" { + continue + } + k := key(def.XSide, def.XMethod) + mi := groups[k] + if mi == nil { + mi = &MethodInfo{Side: def.XSide, Method: def.XMethod} + groups[k] = mi + } + if strings.HasSuffix(name, "Request") { + mi.Req = name + } + if strings.HasSuffix(name, "Response") { + mi.Resp = name + } + if strings.HasSuffix(name, "Notification") { + mi.Notif = name + } + } + // From meta fallback (terminal etc.) + for mk, wire := range meta.AgentMethods { + k := key("agent", wire) + if groups[k] == nil { + base := inferTypeBaseFromMethodKey(mk) + mi := &MethodInfo{Side: "agent", Method: wire} + if wire == "session/cancel" { + mi.Notif = "CancelNotification" + } else { + if _, ok := schema.Defs[base+"Request"]; ok { + mi.Req = base + "Request" + } + if _, ok := schema.Defs[base+"Response"]; ok { + mi.Resp = base + "Response" + } + } + if mi.Req != "" || mi.Notif != "" { + groups[k] = mi + } + } + } + for mk, wire := range meta.ClientMethods { + k := key("client", wire) + if groups[k] == nil { + base := inferTypeBaseFromMethodKey(mk) + mi := &MethodInfo{Side: "client", Method: wire} + if wire == "session/update" { + mi.Notif = "SessionNotification" + } else { + if _, ok := schema.Defs[base+"Request"]; ok { + mi.Req = base + "Request" + } + if _, ok := schema.Defs[base+"Response"]; ok { + mi.Resp = base + "Response" + } + } + if mi.Req != "" || mi.Notif != "" { + groups[k] = mi + } + } + } + // Post-process bindings and docs-ignore + for _, mi := range groups { + mi.Binding = classifyBinding(schema, meta, mi) + mi.DocsIgnored = isDocsIgnoredMethod(schema, mi) + } + return groups +} + +// classifyBinding determines interface binding for each method. +func classifyBinding(schema *load.Schema, meta *load.Meta, mi *MethodInfo) MethodBinding { + if mi == nil { + return BindUnknown + } + switch mi.Side { + case "agent": + if mi.Method == "session/load" { + return BindAgentLoader + } + if isDocsIgnoredMethod(schema, mi) { + return BindAgentExperimental + } + return BindAgent + case "client": + if isDocsIgnoredMethod(schema, mi) { + if strings.HasPrefix(mi.Method, "terminal/") { + return BindClientTerminal + } + return BindClientExperimental + } + return BindClient + default: + return BindUnknown + } +} + +// isDocsIgnoredMethod if any associated type (req/resp/notif) marked x-docs-ignore. +func isDocsIgnoredMethod(schema *load.Schema, mi *MethodInfo) bool { + if mi == nil { + return false + } + if mi.Req != "" { + if d := schema.Defs[mi.Req]; d != nil && d.DocsIgnore { + return true + } + } + if mi.Resp != "" { + if d := schema.Defs[mi.Resp]; d != nil && d.DocsIgnore { + return true + } + } + if mi.Notif != "" { + if d := schema.Defs[mi.Notif]; d != nil && d.DocsIgnore { + return true + } + } + return false +} + +// inferTypeBaseFromMethodKey mirrors previous heuristic; prefer schema when available. +func inferTypeBaseFromMethodKey(methodKey string) string { + if methodKey == "terminal_wait_for_exit" { + return "WaitForTerminalExit" + } + parts := strings.Split(methodKey, "_") + if len(parts) == 2 { + n, v := parts[0], parts[1] + switch v { + case "new", "create", "release", "wait", "load", "authenticate", "prompt", "cancel", "read", "write": + return util.TitleWord(v) + util.TitleWord(n) + default: + return util.TitleWord(n) + util.TitleWord(v) + } + } + segs := strings.Split(methodKey, "_") + for i := range segs { + segs[i] = util.TitleWord(segs[i]) + } + return strings.Join(segs, "") +} + +// DispatchMethodNameForNotification deduces trait method name for notifications. +func DispatchMethodNameForNotification(methodKey, typeName string) string { + switch methodKey { + case "session_update": + return "SessionUpdate" + case "session_cancel": + return "Cancel" + default: + if strings.HasSuffix(typeName, "Notification") { + return strings.TrimSuffix(typeName, "Notification") + } + return typeName + } +} + +// SortedKeys returns sorted keys of a map. +func SortedKeys(m map[string]string) []string { + ks := make([]string, 0, len(m)) + for k := range m { + ks = append(ks, k) + } + sort.Strings(ks) + return ks +} diff --git a/go/cmd/generate/internal/load/load.go b/go/cmd/generate/internal/load/load.go new file mode 100644 index 0000000..27cbd16 --- /dev/null +++ b/go/cmd/generate/internal/load/load.go @@ -0,0 +1,64 @@ +package load + +import ( + "encoding/json" + "fmt" + "os" + "path/filepath" +) + +// Meta mirrors schema/meta.json for method maps and version. +type Meta struct { + Version int `json:"version"` + AgentMethods map[string]string `json:"agentMethods"` + ClientMethods map[string]string `json:"clientMethods"` +} + +// Schema is a minimal view over schema/schema.json definitions used by the generator. +type Schema struct { + Defs map[string]*Definition `json:"$defs"` +} + +// Definition is a partial JSON Schema node the generator cares about. +type Definition struct { + Description string `json:"description"` + Type any `json:"type"` + Properties map[string]*Definition `json:"properties"` + Required []string `json:"required"` + Enum []any `json:"enum"` + Items *Definition `json:"items"` + Ref string `json:"$ref"` + AnyOf []*Definition `json:"anyOf"` + OneOf []*Definition `json:"oneOf"` + DocsIgnore bool `json:"x-docs-ignore"` + Title string `json:"title"` + Const any `json:"const"` + XSide string `json:"x-side"` + XMethod string `json:"x-method"` +} + +// ReadMeta loads schema/meta.json. +func ReadMeta(schemaDir string) (*Meta, error) { + metaBytes, err := os.ReadFile(filepath.Join(schemaDir, "meta.json")) + if err != nil { + return nil, fmt.Errorf("read meta.json: %w", err) + } + var meta Meta + if err := json.Unmarshal(metaBytes, &meta); err != nil { + return nil, fmt.Errorf("parse meta.json: %w", err) + } + return &meta, nil +} + +// ReadSchema loads schema/schema.json. +func ReadSchema(schemaDir string) (*Schema, error) { + schemaBytes, err := os.ReadFile(filepath.Join(schemaDir, "schema.json")) + if err != nil { + return nil, fmt.Errorf("read schema.json: %w", err) + } + var schema Schema + if err := json.Unmarshal(schemaBytes, &schema); err != nil { + return nil, fmt.Errorf("parse schema.json: %w", err) + } + return &schema, nil +} diff --git a/go/cmd/generate/internal/util/util.go b/go/cmd/generate/internal/util/util.go new file mode 100644 index 0000000..8912623 --- /dev/null +++ b/go/cmd/generate/internal/util/util.go @@ -0,0 +1,76 @@ +package util + +import ( + "strings" + "unicode" +) + +// SanitizeComment removes backticks and normalizes whitespace for Go comments. +func SanitizeComment(s string) string { + s = strings.ReplaceAll(s, "`", "'") + lines := strings.Split(s, "\n") + for i := range lines { + lines[i] = strings.TrimSpace(lines[i]) + } + return strings.Join(lines, " ") +} + +// TitleWord uppercases the first rune and lowercases the rest. +func TitleWord(s string) string { + if s == "" { + return s + } + r := []rune(s) + r[0] = unicode.ToUpper(r[0]) + for i := 1; i < len(r); i++ { + r[i] = unicode.ToLower(r[i]) + } + return string(r) +} + +// SplitCamel splits a camelCase string into tokens. +func SplitCamel(s string) []string { + var parts []string + last := 0 + for i := 1; i < len(s); i++ { + if isBoundary(s[i-1], s[i]) { + parts = append(parts, s[last:i]) + last = i + } + } + parts = append(parts, s[last:]) + return parts +} + +func isBoundary(prev, curr byte) bool { + return (prev >= 'a' && prev <= 'z' && curr >= 'A' && curr <= 'Z') || curr == '_' +} + +// ToExportedField converts snake_case or camelCase to PascalCase. +func ToExportedField(name string) string { + parts := strings.Split(name, "_") + if len(parts) == 1 { + parts = SplitCamel(name) + } + for i := range parts { + parts[i] = TitleWord(parts[i]) + } + return strings.Join(parts, "") +} + +// ToEnumConst builds a const identifier like . +func ToEnumConst(typeName, val string) string { + cleaned := make([]rune, 0, len(val)) + for _, r := range val { + if (r >= 'a' && r <= 'z') || (r >= 'A' && r <= 'Z') || (r >= '0' && r <= '9') { + cleaned = append(cleaned, r) + } else { + cleaned = append(cleaned, '_') + } + } + parts := strings.FieldsFunc(string(cleaned), func(r rune) bool { return r == '_' }) + for i := range parts { + parts[i] = TitleWord(strings.ToLower(parts[i])) + } + return typeName + strings.Join(parts, "") +} diff --git a/go/cmd/generate/main.go b/go/cmd/generate/main.go index 744bac6..1b01b97 100644 --- a/go/cmd/generate/main.go +++ b/go/cmd/generate/main.go @@ -1,94 +1,58 @@ package main import ( - "bytes" - "encoding/json" - "fmt" + "flag" "os" "path/filepath" - "sort" - "strings" - "unicode" - //nolint // We intentionally use dot-import for Jennifer codegen readability - . "github.com/dave/jennifer/jen" + "github.com/zed-industries/agent-client-protocol/go/cmd/generate/internal/emit" + "github.com/zed-industries/agent-client-protocol/go/cmd/generate/internal/load" ) -type Meta struct { - Version int `json:"version"` - AgentMethods map[string]string `json:"agentMethods"` - ClientMethods map[string]string `json:"clientMethods"` -} - -type Schema struct { - Defs map[string]*Definition `json:"$defs"` -} - -type Definition struct { - Description string `json:"description"` - Type any `json:"type"` - Properties map[string]*Definition `json:"properties"` - Required []string `json:"required"` - Enum []any `json:"enum"` - Items *Definition `json:"items"` - Ref string `json:"$ref"` - AnyOf []*Definition `json:"anyOf"` - OneOf []*Definition `json:"oneOf"` - DocsIgnore bool `json:"x-docs-ignore"` - Title string `json:"title"` - Const any `json:"const"` - XSide string `json:"x-side"` - XMethod string `json:"x-method"` -} - -// methodInfo captures the association between a wire method and its Go types. -type methodInfo struct{ Side, Method, Req, Resp, Notif string } - func main() { + var schemaDirFlag string + var outDirFlag string + flag.StringVar(&schemaDirFlag, "schema", "", "path to schema directory (defaults to /schema)") + flag.StringVar(&outDirFlag, "out", "", "output directory for generated go files (defaults to /go)") + flag.Parse() + repoRoot := findRepoRoot() - schemaDir := filepath.Join(repoRoot, "schema") - outDir := filepath.Join(repoRoot, "go") + schemaDir := schemaDirFlag + outDir := outDirFlag + if schemaDir == "" { + schemaDir = filepath.Join(repoRoot, "schema") + } + if outDir == "" { + outDir = filepath.Join(repoRoot, "go") + } if err := os.MkdirAll(outDir, 0o755); err != nil { panic(err) } - // Read meta.json - metaBytes, err := os.ReadFile(filepath.Join(schemaDir, "meta.json")) + meta, err := load.ReadMeta(schemaDir) if err != nil { - panic(fmt.Errorf("read meta.json: %w", err)) - } - var meta Meta - if err := json.Unmarshal(metaBytes, &meta); err != nil { - panic(fmt.Errorf("parse meta.json: %w", err)) + panic(err) } - // Write constants.go - if err := writeConstantsJen(outDir, &meta); err != nil { + if err := emit.WriteConstantsJen(outDir, meta); err != nil { panic(err) } - // Read schema.json - schemaBytes, err := os.ReadFile(filepath.Join(schemaDir, "schema.json")) + schema, err := load.ReadSchema(schemaDir) if err != nil { - panic(fmt.Errorf("read schema.json: %w", err)) - } - var schema Schema - if err := json.Unmarshal(schemaBytes, &schema); err != nil { - panic(fmt.Errorf("parse schema.json: %w", err)) + panic(err) } - if err := writeTypesJen(outDir, &schema, &meta); err != nil { + if err := emit.WriteTypesJen(outDir, schema, meta); err != nil { panic(err) } - - if err := writeDispatchJen(outDir, &schema, &meta); err != nil { + if err := emit.WriteDispatchJen(outDir, schema, meta); err != nil { panic(err) } } func findRepoRoot() string { - // Assume this generator runs from repo root or subfolders; walk up to find package.json cwd, _ := os.Getwd() dir := cwd for i := 0; i < 10; i++ { @@ -103,1468 +67,3 @@ func findRepoRoot() string { } return cwd } - -func writeConstantsJen(outDir string, meta *Meta) error { - f := NewFile("acp") - f.HeaderComment("Code generated by acp-go-generator; DO NOT EDIT.") - f.Comment("ProtocolVersionNumber is the ACP protocol version supported by this SDK.") - f.Const().Id("ProtocolVersionNumber").Op("=").Lit(meta.Version) - - // Agent methods (deterministic order) - amKeys := make([]string, 0, len(meta.AgentMethods)) - for k := range meta.AgentMethods { - amKeys = append(amKeys, k) - } - sort.Strings(amKeys) - var agentDefs []Code - for _, k := range amKeys { - wire := meta.AgentMethods[k] - agentDefs = append(agentDefs, Id("AgentMethod"+toExportedConst(k)).Op("=").Lit(wire)) - } - f.Comment("Agent method names") - f.Const().Defs(agentDefs...) - - // Client methods (deterministic order) - cmKeys := make([]string, 0, len(meta.ClientMethods)) - for k := range meta.ClientMethods { - cmKeys = append(cmKeys, k) - } - sort.Strings(cmKeys) - var clientDefs []Code - for _, k := range cmKeys { - wire := meta.ClientMethods[k] - clientDefs = append(clientDefs, Id("ClientMethod"+toExportedConst(k)).Op("=").Lit(wire)) - } - f.Comment("Client method names") - f.Const().Defs(clientDefs...) - - var buf bytes.Buffer - if err := f.Render(&buf); err != nil { - return err - } - return os.WriteFile(filepath.Join(outDir, "constants.go"), buf.Bytes(), 0o644) -} - -func toExportedConst(s string) string { - // Convert snake_case like session_new to SessionNew - parts := strings.Split(s, "_") - for i := range parts { - parts[i] = titleWord(parts[i]) - } - return strings.Join(parts, "") -} - -func writeTypesJen(outDir string, schema *Schema, meta *Meta) error { - f := NewFile("acp") - f.HeaderComment("Code generated by acp-go-generator; DO NOT EDIT.") - - // Deterministic order - keys := make([]string, 0, len(schema.Defs)) - for k := range schema.Defs { - keys = append(keys, k) - } - sort.Strings(keys) - - for _, name := range keys { - def := schema.Defs[name] - if def == nil { - continue - } - - // Type-level comment - if def.Description != "" { - f.Comment(sanitizeComment(def.Description)) - } - - switch { - case len(def.Enum) > 0: - // string enum - f.Type().Id(name).String() - // const block - defs := []Code{} - for _, v := range def.Enum { - s := fmt.Sprint(v) - defs = append(defs, Id(toEnumConst(name, s)).Id(name).Op("=").Lit(s)) - } - if len(defs) > 0 { - f.Const().Defs(defs...) - } - f.Line() - case isStringConstUnion(def): - f.Type().Id(name).String() - defs := []Code{} - for _, v := range def.OneOf { - if v != nil && v.Const != nil { - s := fmt.Sprint(v.Const) - defs = append(defs, Id(toEnumConst(name, s)).Id(name).Op("=").Lit(s)) - } - } - if len(defs) > 0 { - f.Const().Defs(defs...) - } - f.Line() - case name == "ContentBlock": - emitContentBlockJen(f) - case name == "ToolCallContent": - emitToolCallContentJen(f) - case name == "EmbeddedResourceResource": - emitEmbeddedResourceResourceJen(f) - case name == "RequestPermissionOutcome": - emitRequestPermissionOutcomeJen(f) - case name == "SessionUpdate": - emitSessionUpdateJen(f) - case primaryType(def) == "object" && len(def.Properties) > 0: - // Build struct fields - st := []Code{} - // required lookup - req := map[string]struct{}{} - for _, r := range def.Required { - req[r] = struct{}{} - } - // sorted properties - pkeys := make([]string, 0, len(def.Properties)) - for pk := range def.Properties { - pkeys = append(pkeys, pk) - } - sort.Strings(pkeys) - for _, pk := range pkeys { - prop := def.Properties[pk] - field := toExportedField(pk) - // field comment must directly precede field without blank line - if prop.Description != "" { - st = append(st, Comment(sanitizeComment(prop.Description))) - } - tag := pk - if _, ok := req[pk]; !ok { - tag = pk + ",omitempty" - } - st = append(st, Id(field).Add(jenTypeForOptional(prop)).Tag(map[string]string{"json": tag})) - } - f.Type().Id(name).Struct(st...) - f.Line() - case primaryType(def) == "string" || primaryType(def) == "integer" || primaryType(def) == "number" || primaryType(def) == "boolean": - f.Type().Id(name).Add(primitiveJenType(primaryType(def))) - f.Line() - default: - // unions etc. - f.Comment(fmt.Sprintf("%s is a union or complex schema; represented generically.", name)) - f.Type().Id(name).Any() - f.Line() - } - - // Emit basic validators for RPC/union types - if strings.HasSuffix(name, "Request") || strings.HasSuffix(name, "Response") || strings.HasSuffix(name, "Notification") || name == "ContentBlock" || name == "ToolCallContent" || name == "SessionUpdate" || name == "ToolCallUpdate" { - emitValidateJen(f, name, def) - } - } - - // Append Agent and Client interfaces derived from meta.json + schema defs - { - groups := buildMethodGroups(schema, meta) - - // Helper: determine if a method is undocumented (x-docs-ignore) - isDocsIgnored := func(mi *methodInfo) bool { - if mi == nil { - return false - } - if mi.Req != "" { - if d := schema.Defs[mi.Req]; d != nil && d.DocsIgnore { - return true - } - } - if mi.Resp != "" { - if d := schema.Defs[mi.Resp]; d != nil && d.DocsIgnore { - return true - } - } - if mi.Notif != "" { - if d := schema.Defs[mi.Notif]; d != nil && d.DocsIgnore { - return true - } - } - return false - } - - // Agent - agentMethods := []Code{} - // Optional loader methods live on a separate interface - agentLoaderMethods := []Code{} - // Undocumented/experimental methods live on a separate interface - agentExperimentalMethods := []Code{} - amKeys := make([]string, 0, len(meta.AgentMethods)) - for k := range meta.AgentMethods { - amKeys = append(amKeys, k) - } - sort.Strings(amKeys) - for _, k := range amKeys { - wire := meta.AgentMethods[k] - mi := groups["agent|"+wire] - if mi == nil { - continue - } - // Treat session/load as optional (AgentLoader) - target := &agentMethods - if wire == "session/load" { - target = &agentLoaderMethods - } - // Undocumented/experimental agent methods go to AgentExperimental - if isDocsIgnored(mi) { - target = &agentExperimentalMethods - } - if mi.Notif != "" { - name := dispatchMethodNameForNotification(k, mi.Notif) - *target = append(*target, Id(name).Params(Id("params").Id(mi.Notif)).Error()) - } else if mi.Req != "" { - respName := strings.TrimSuffix(mi.Req, "Request") + "Response" - methodName := strings.TrimSuffix(mi.Req, "Request") - if isNullResponse(schema.Defs[respName]) { - *target = append(*target, Id(methodName).Params(Id("params").Id(mi.Req)).Error()) - } else { - *target = append(*target, Id(methodName).Params(Id("params").Id(mi.Req)).Params(Id(respName), Error())) - } - } - } - // Emit interfaces - f.Type().Id("Agent").Interface(agentMethods...) - if len(agentLoaderMethods) > 0 { - f.Comment("AgentLoader defines optional support for loading sessions. Implement and advertise the capability to enable 'session/load'.") - f.Type().Id("AgentLoader").Interface(agentLoaderMethods...) - } - if len(agentExperimentalMethods) > 0 { - f.Comment("AgentExperimental defines undocumented/experimental methods (x-docs-ignore). These may change or be removed without notice.") - f.Type().Id("AgentExperimental").Interface(agentExperimentalMethods...) - } - - // Client - clientStable := []Code{} - clientExperimental := []Code{} - clientTerminal := []Code{} - cmKeys := make([]string, 0, len(meta.ClientMethods)) - for k := range meta.ClientMethods { - cmKeys = append(cmKeys, k) - } - sort.Strings(cmKeys) - for _, k := range cmKeys { - wire := meta.ClientMethods[k] - mi := groups["client|"+wire] - if mi == nil { - continue - } - target := &clientStable - if isDocsIgnored(mi) { - if strings.HasPrefix(wire, "terminal/") { - target = &clientTerminal - } else { - target = &clientExperimental - } - } - if mi.Notif != "" { - name := dispatchMethodNameForNotification(k, mi.Notif) - *target = append(*target, Id(name).Params(Id("params").Id(mi.Notif)).Error()) - } else if mi.Req != "" { - respName := strings.TrimSuffix(mi.Req, "Request") + "Response" - methodName := strings.TrimSuffix(mi.Req, "Request") - if isNullResponse(schema.Defs[respName]) { - *target = append(*target, Id(methodName).Params(Id("params").Id(mi.Req)).Error()) - } else { - *target = append(*target, Id(methodName).Params(Id("params").Id(mi.Req)).Params(Id(respName), Error())) - } - } - } - f.Type().Id("Client").Interface(clientStable...) - if len(clientTerminal) > 0 { - f.Comment("ClientTerminal defines terminal-related experimental methods (x-docs-ignore). Implement and advertise 'terminal: true' to enable 'terminal/*'.") - f.Type().Id("ClientTerminal").Interface(clientTerminal...) - } - if len(clientExperimental) > 0 { - f.Comment("ClientExperimental defines undocumented/experimental methods (x-docs-ignore) other than terminals. These may change or be removed without notice.") - f.Type().Id("ClientExperimental").Interface(clientExperimental...) - } - } - - var buf bytes.Buffer - if err := f.Render(&buf); err != nil { - return err - } - return os.WriteFile(filepath.Join(outDir, "types.go"), buf.Bytes(), 0o644) -} - -func isStringConstUnion(def *Definition) bool { - if def == nil || len(def.OneOf) == 0 { - return false - } - for _, v := range def.OneOf { - if v == nil || v.Const == nil { - return false - } - if _, ok := v.Const.(string); !ok { - return false - } - } - return true -} - -// isDocsIgnoredMethod returns true if any of the method's associated types -// (request, response, notification) are marked with x-docs-ignore in the schema. -func isDocsIgnoredMethod(schema *Schema, mi *methodInfo) bool { - if mi == nil { - return false - } - if mi.Req != "" { - if d := schema.Defs[mi.Req]; d != nil && d.DocsIgnore { - return true - } - } - if mi.Resp != "" { - if d := schema.Defs[mi.Resp]; d != nil && d.DocsIgnore { - return true - } - } - if mi.Notif != "" { - if d := schema.Defs[mi.Notif]; d != nil && d.DocsIgnore { - return true - } - } - return false -} - -// emitValidateJen generates a simple Validate method for selected types. -func emitValidateJen(f *File, name string, def *Definition) { - switch name { - case "ContentBlock": - f.Func().Params(Id("c").Op("*").Id("ContentBlock")).Id("Validate").Params().Params(Error()).Block( - Switch(Id("c").Dot("Type")).Block( - Case(Lit("text")).Block(If(Id("c").Dot("Text").Op("==").Nil()).Block(Return(Qual("fmt", "Errorf").Call(Lit("contentblock.text missing"))))), - Case(Lit("image")).Block(If(Id("c").Dot("Image").Op("==").Nil()).Block(Return(Qual("fmt", "Errorf").Call(Lit("contentblock.image missing"))))), - Case(Lit("audio")).Block(If(Id("c").Dot("Audio").Op("==").Nil()).Block(Return(Qual("fmt", "Errorf").Call(Lit("contentblock.audio missing"))))), - Case(Lit("resource_link")).Block(If(Id("c").Dot("ResourceLink").Op("==").Nil()).Block(Return(Qual("fmt", "Errorf").Call(Lit("contentblock.resource_link missing"))))), - Case(Lit("resource")).Block(If(Id("c").Dot("Resource").Op("==").Nil()).Block(Return(Qual("fmt", "Errorf").Call(Lit("contentblock.resource missing"))))), - ), - Return(Nil()), - ) - return - case "ToolCallContent": - f.Func().Params(Id("t").Op("*").Id("ToolCallContent")).Id("Validate").Params().Params(Error()).Block( - Switch(Id("t").Dot("Type")).Block( - Case(Lit("content")).Block(If(Id("t").Dot("Content").Op("==").Nil()).Block(Return(Qual("fmt", "Errorf").Call(Lit("toolcallcontent.content missing"))))), - Case(Lit("diff")).Block(If(Id("t").Dot("Diff").Op("==").Nil()).Block(Return(Qual("fmt", "Errorf").Call(Lit("toolcallcontent.diff missing"))))), - Case(Lit("terminal")).Block(If(Id("t").Dot("Terminal").Op("==").Nil()).Block(Return(Qual("fmt", "Errorf").Call(Lit("toolcallcontent.terminal missing"))))), - ), - Return(Nil()), - ) - return - case "SessionUpdate": - f.Func().Params(Id("s").Op("*").Id("SessionUpdate")).Id("Validate").Params().Params(Error()).Block( - Var().Id("count").Int(), - If(Id("s").Dot("UserMessageChunk").Op("!=").Nil()).Block(Id("count").Op("++")), - If(Id("s").Dot("AgentMessageChunk").Op("!=").Nil()).Block(Id("count").Op("++")), - If(Id("s").Dot("AgentThoughtChunk").Op("!=").Nil()).Block(Id("count").Op("++")), - If(Id("s").Dot("ToolCall").Op("!=").Nil()).Block(Id("count").Op("++")), - If(Id("s").Dot("ToolCallUpdate").Op("!=").Nil()).Block(Id("count").Op("++")), - If(Id("s").Dot("Plan").Op("!=").Nil()).Block(Id("count").Op("++")), - If(Id("count").Op("!=").Lit(1)).Block(Return(Qual("fmt", "Errorf").Call(Lit("sessionupdate must have exactly one variant set")))), - Return(Nil()), - ) - return - case "ToolCallUpdate": - f.Func().Params(Id("t").Op("*").Id("ToolCallUpdate")).Id("Validate").Params().Params(Error()).Block( - If(Id("t").Dot("ToolCallId").Op("==").Lit("")).Block(Return(Qual("fmt", "Errorf").Call(Lit("toolCallId is required")))), - Return(Nil()), - ) - return - } - // Generic RPC objects - if def != nil && primaryType(def) == "object" { - if !(strings.HasSuffix(name, "Request") || strings.HasSuffix(name, "Response") || strings.HasSuffix(name, "Notification")) { - return - } - f.Func().Params(Id("v").Op("*").Id(name)).Id("Validate").Params().Params(Error()).BlockFunc(func(g *Group) { - // Emit checks in deterministic property order - pkeys := make([]string, 0, len(def.Properties)) - for pk := range def.Properties { - pkeys = append(pkeys, pk) - } - sort.Strings(pkeys) - for _, propName := range pkeys { - pDef := def.Properties[propName] - // is required? - required := false - for _, r := range def.Required { - if r == propName { - required = true - break - } - } - field := toExportedField(propName) - if required { - switch primaryType(pDef) { - case "string": - g.If(Id("v").Dot(field).Op("==").Lit("")).Block(Return(Qual("fmt", "Errorf").Call(Lit(propName + " is required")))) - case "array": - g.If(Id("v").Dot(field).Op("==").Nil()).Block(Return(Qual("fmt", "Errorf").Call(Lit(propName + " is required")))) - } - } - } - g.Return(Nil()) - }) - } -} - -// buildMethodGroups merges schema-provided links with inferred ones from meta. -func buildMethodGroups(schema *Schema, meta *Meta) map[string]*methodInfo { - groups := map[string]*methodInfo{} - // From schema - for name, def := range schema.Defs { - if def == nil || def.XMethod == "" || def.XSide == "" { - continue - } - key := def.XSide + "|" + def.XMethod - mi := groups[key] - if mi == nil { - mi = &methodInfo{Side: def.XSide, Method: def.XMethod} - groups[key] = mi - } - if strings.HasSuffix(name, "Request") { - mi.Req = name - } - if strings.HasSuffix(name, "Response") { - mi.Resp = name - } - if strings.HasSuffix(name, "Notification") { - mi.Notif = name - } - } - // From meta fallback (e.g., terminal methods) - for key, wire := range meta.AgentMethods { - k := "agent|" + wire - if groups[k] == nil { - base := inferTypeBaseFromMethodKey(key) - mi := &methodInfo{Side: "agent", Method: wire} - if wire == "session/cancel" { - mi.Notif = "CancelNotification" - } else { - if _, ok := schema.Defs[base+"Request"]; ok { - mi.Req = base + "Request" - } - if _, ok := schema.Defs[base+"Response"]; ok { - mi.Resp = base + "Response" - } - } - if mi.Req != "" || mi.Notif != "" { - groups[k] = mi - } - } - } - for key, wire := range meta.ClientMethods { - k := "client|" + wire - if groups[k] == nil { - base := inferTypeBaseFromMethodKey(key) - mi := &methodInfo{Side: "client", Method: wire} - if wire == "session/update" { - mi.Notif = "SessionNotification" - } else { - if _, ok := schema.Defs[base+"Request"]; ok { - mi.Req = base + "Request" - } - if _, ok := schema.Defs[base+"Response"]; ok { - mi.Resp = base + "Response" - } - } - if mi.Req != "" || mi.Notif != "" { - groups[k] = mi - } - } - } - return groups -} - -func inferTypeBaseFromMethodKey(methodKey string) string { - // Special-case known irregular mappings - if methodKey == "terminal_wait_for_exit" { - return "WaitForTerminalExit" - } - parts := strings.Split(methodKey, "_") - if len(parts) == 2 { - n, v := parts[0], parts[1] - switch v { - case "new", "create", "release", "wait", "load", "authenticate", "prompt", "cancel", "read", "write": - return titleWord(v) + titleWord(n) - default: - return titleWord(n) + titleWord(v) - } - } - segs := strings.Split(methodKey, "_") - for i := range segs { - segs[i] = titleWord(segs[i]) - } - return strings.Join(segs, "") -} - -func emitContentBlockJen(f *File) { - // ResourceLinkContent helper - f.Type().Id("ResourceLinkContent").Struct( - Id("Annotations").Any().Tag(map[string]string{"json": "annotations,omitempty"}), - Id("Description").Op("*").String().Tag(map[string]string{"json": "description,omitempty"}), - Id("MimeType").Op("*").String().Tag(map[string]string{"json": "mimeType,omitempty"}), - Id("Name").String().Tag(map[string]string{"json": "name"}), - Id("Size").Op("*").Int64().Tag(map[string]string{"json": "size,omitempty"}), - Id("Title").Op("*").String().Tag(map[string]string{"json": "title,omitempty"}), - Id("Uri").String().Tag(map[string]string{"json": "uri"}), - ) - f.Line() - // ContentBlock - f.Type().Id("ContentBlock").Struct( - Id("Type").String().Tag(map[string]string{"json": "type"}), - Id("Text").Op("*").Id("TextContent").Tag(map[string]string{"json": "-"}), - Id("Image").Op("*").Id("ImageContent").Tag(map[string]string{"json": "-"}), - Id("Audio").Op("*").Id("AudioContent").Tag(map[string]string{"json": "-"}), - Id("ResourceLink").Op("*").Id("ResourceLinkContent").Tag(map[string]string{"json": "-"}), - Id("Resource").Op("*").Id("EmbeddedResource").Tag(map[string]string{"json": "-"}), - ) - f.Line() - // UnmarshalJSON for ContentBlock - f.Func().Params(Id("c").Op("*").Id("ContentBlock")).Id("UnmarshalJSON").Params(Id("b").Index().Byte()).Error().Block( - Var().Id("probe").Struct(Id("Type").String().Tag(map[string]string{"json": "type"})), - If(List(Id("err")).Op(":=").Qual("encoding/json", "Unmarshal").Call(Id("b"), Op("&").Id("probe")), Id("err").Op("!=").Nil()).Block(Return(Id("err"))), - Id("c").Dot("Type").Op("=").Id("probe").Dot("Type"), - Switch(Id("probe").Dot("Type")).Block( - Case(Lit("text")).Block( - Var().Id("v").Id("TextContent"), - If(List(Id("err")).Op(":=").Qual("encoding/json", "Unmarshal").Call(Id("b"), Op("&").Id("v")), Id("err").Op("!=").Nil()).Block(Return(Id("err"))), - Id("c").Dot("Text").Op("=").Op("&").Id("v"), - ), - Case(Lit("image")).Block( - Var().Id("v").Id("ImageContent"), - If(List(Id("err")).Op(":=").Qual("encoding/json", "Unmarshal").Call(Id("b"), Op("&").Id("v")), Id("err").Op("!=").Nil()).Block(Return(Id("err"))), - Id("c").Dot("Image").Op("=").Op("&").Id("v"), - ), - Case(Lit("audio")).Block( - Var().Id("v").Id("AudioContent"), - If(List(Id("err")).Op(":=").Qual("encoding/json", "Unmarshal").Call(Id("b"), Op("&").Id("v")), Id("err").Op("!=").Nil()).Block(Return(Id("err"))), - Id("c").Dot("Audio").Op("=").Op("&").Id("v"), - ), - Case(Lit("resource_link")).Block( - Var().Id("v").Id("ResourceLinkContent"), - If(List(Id("err")).Op(":=").Qual("encoding/json", "Unmarshal").Call(Id("b"), Op("&").Id("v")), Id("err").Op("!=").Nil()).Block(Return(Id("err"))), - Id("c").Dot("ResourceLink").Op("=").Op("&").Id("v"), - ), - Case(Lit("resource")).Block( - Var().Id("v").Id("EmbeddedResource"), - If(List(Id("err")).Op(":=").Qual("encoding/json", "Unmarshal").Call(Id("b"), Op("&").Id("v")), Id("err").Op("!=").Nil()).Block(Return(Id("err"))), - Id("c").Dot("Resource").Op("=").Op("&").Id("v"), - ), - ), - Return(Nil()), - ) - // MarshalJSON for ContentBlock - f.Func().Params(Id("c").Id("ContentBlock")).Id("MarshalJSON").Params().Params(Index().Byte(), Error()).Block( - Switch(Id("c").Dot("Type")).Block( - Case(Lit("text")).Block( - If(Id("c").Dot("Text").Op("!=").Nil()).Block( - Return(Qual("encoding/json", "Marshal").Call(Map(String()).Any().Values(Dict{ - Lit("type"): Lit("text"), - Lit("text"): Id("c").Dot("Text").Dot("Text"), - }))), - ), - ), - Case(Lit("image")).Block( - If(Id("c").Dot("Image").Op("!=").Nil()).Block( - Return(Qual("encoding/json", "Marshal").Call(Map(String()).Any().Values(Dict{ - Lit("type"): Lit("image"), - Lit("data"): Id("c").Dot("Image").Dot("Data"), - Lit("mimeType"): Id("c").Dot("Image").Dot("MimeType"), - Lit("uri"): Id("c").Dot("Image").Dot("Uri"), - }))), - ), - ), - Case(Lit("audio")).Block( - If(Id("c").Dot("Audio").Op("!=").Nil()).Block( - Return(Qual("encoding/json", "Marshal").Call(Map(String()).Any().Values(Dict{ - Lit("type"): Lit("audio"), - Lit("data"): Id("c").Dot("Audio").Dot("Data"), - Lit("mimeType"): Id("c").Dot("Audio").Dot("MimeType"), - }))), - ), - ), - Case(Lit("resource_link")).Block( - If(Id("c").Dot("ResourceLink").Op("!=").Nil()).Block( - Return(Qual("encoding/json", "Marshal").Call(Map(String()).Any().Values(Dict{ - Lit("type"): Lit("resource_link"), - Lit("name"): Id("c").Dot("ResourceLink").Dot("Name"), - Lit("uri"): Id("c").Dot("ResourceLink").Dot("Uri"), - Lit("description"): Id("c").Dot("ResourceLink").Dot("Description"), - Lit("mimeType"): Id("c").Dot("ResourceLink").Dot("MimeType"), - Lit("size"): Id("c").Dot("ResourceLink").Dot("Size"), - Lit("title"): Id("c").Dot("ResourceLink").Dot("Title"), - }))), - ), - ), - Case(Lit("resource")).Block( - If(Id("c").Dot("Resource").Op("!=").Nil()).Block( - Return(Qual("encoding/json", "Marshal").Call(Map(String()).Any().Values(Dict{ - Lit("type"): Lit("resource"), - Lit("resource"): Id("c").Dot("Resource").Dot("Resource"), - }))), - ), - ), - ), - Return(Index().Byte().Values(), Nil()), - ) - f.Line() -} - -func emitToolCallContentJen(f *File) { - // Helpers - f.Type().Id("DiffContent").Struct( - Id("NewText").String().Tag(map[string]string{"json": "newText"}), - Id("OldText").Op("*").String().Tag(map[string]string{"json": "oldText,omitempty"}), - Id("Path").String().Tag(map[string]string{"json": "path"}), - ) - f.Type().Id("TerminalRef").Struct(Id("TerminalId").String().Tag(map[string]string{"json": "terminalId"})) - f.Line() - // ToolCallContent - f.Type().Id("ToolCallContent").Struct( - Id("Type").String().Tag(map[string]string{"json": "type"}), - Id("Content").Op("*").Id("ContentBlock").Tag(map[string]string{"json": "-"}), - Id("Diff").Op("*").Id("DiffContent").Tag(map[string]string{"json": "-"}), - Id("Terminal").Op("*").Id("TerminalRef").Tag(map[string]string{"json": "-"}), - ) - f.Line() - // UnmarshalJSON - f.Func().Params(Id("t").Op("*").Id("ToolCallContent")).Id("UnmarshalJSON").Params(Id("b").Index().Byte()).Error().Block( - Var().Id("probe").Struct(Id("Type").String().Tag(map[string]string{"json": "type"})), - If(List(Id("err")).Op(":=").Qual("encoding/json", "Unmarshal").Call(Id("b"), Op("&").Id("probe")), Id("err").Op("!=").Nil()).Block(Return(Id("err"))), - Id("t").Dot("Type").Op("=").Id("probe").Dot("Type"), - Switch(Id("probe").Dot("Type")).Block( - Case(Lit("content")).Block( - Var().Id("v").Struct( - Id("Type").String().Tag(map[string]string{"json": "type"}), - Id("Content").Id("ContentBlock").Tag(map[string]string{"json": "content"}), - ), - If(List(Id("err")).Op(":=").Qual("encoding/json", "Unmarshal").Call(Id("b"), Op("&").Id("v")), Id("err").Op("!=").Nil()).Block(Return(Id("err"))), - Id("t").Dot("Content").Op("=").Op("&").Id("v").Dot("Content"), - ), - Case(Lit("diff")).Block( - Var().Id("v").Id("DiffContent"), - If(List(Id("err")).Op(":=").Qual("encoding/json", "Unmarshal").Call(Id("b"), Op("&").Id("v")), Id("err").Op("!=").Nil()).Block(Return(Id("err"))), - Id("t").Dot("Diff").Op("=").Op("&").Id("v"), - ), - Case(Lit("terminal")).Block( - Var().Id("v").Id("TerminalRef"), - If(List(Id("err")).Op(":=").Qual("encoding/json", "Unmarshal").Call(Id("b"), Op("&").Id("v")), Id("err").Op("!=").Nil()).Block(Return(Id("err"))), - Id("t").Dot("Terminal").Op("=").Op("&").Id("v"), - ), - ), - Return(Nil()), - ) - f.Line() -} - -func emitEmbeddedResourceResourceJen(f *File) { - // Holder with pointers to known variants - f.Type().Id("EmbeddedResourceResource").Struct( - Id("TextResourceContents").Op("*").Id("TextResourceContents").Tag(map[string]string{"json": "-"}), - Id("BlobResourceContents").Op("*").Id("BlobResourceContents").Tag(map[string]string{"json": "-"}), - ) - f.Line() - f.Func().Params(Id("e").Op("*").Id("EmbeddedResourceResource")).Id("UnmarshalJSON").Params(Id("b").Index().Byte()).Error().Block( - // Decide by presence of distinguishing keys - Var().Id("m").Map(String()).Qual("encoding/json", "RawMessage"), - If(List(Id("err")).Op(":=").Qual("encoding/json", "Unmarshal").Call(Id("b"), Op("&").Id("m")), Id("err").Op("!=").Nil()).Block(Return(Id("err"))), - // TextResourceContents has "text" key - If(List(Id("_"), Id("ok")).Op(":=").Id("m").Index(Lit("text")), Id("ok")).Block( - Var().Id("v").Id("TextResourceContents"), - If(List(Id("err")).Op(":=").Qual("encoding/json", "Unmarshal").Call(Id("b"), Op("&").Id("v")), Id("err").Op("!=").Nil()).Block(Return(Id("err"))), - Id("e").Dot("TextResourceContents").Op("=").Op("&").Id("v"), - Return(Nil()), - ), - // BlobResourceContents has "blob" key - If(List(Id("_"), Id("ok2")).Op(":=").Id("m").Index(Lit("blob")), Id("ok2")).Block( - Var().Id("v").Id("BlobResourceContents"), - If(List(Id("err")).Op(":=").Qual("encoding/json", "Unmarshal").Call(Id("b"), Op("&").Id("v")), Id("err").Op("!=").Nil()).Block(Return(Id("err"))), - Id("e").Dot("BlobResourceContents").Op("=").Op("&").Id("v"), - Return(Nil()), - ), - Return(Nil()), - ) - f.Line() -} - -func emitRequestPermissionOutcomeJen(f *File) { - // Variants - f.Type().Id("RequestPermissionOutcomeCancelled").Struct() - f.Type().Id("RequestPermissionOutcomeSelected").Struct( - Id("OptionId").Id("PermissionOptionId").Tag(map[string]string{"json": "optionId"}), - ) - f.Line() - // Holder - f.Type().Id("RequestPermissionOutcome").Struct( - Id("Cancelled").Op("*").Id("RequestPermissionOutcomeCancelled").Tag(map[string]string{"json": "-"}), - Id("Selected").Op("*").Id("RequestPermissionOutcomeSelected").Tag(map[string]string{"json": "-"}), - ) - f.Func().Params(Id("o").Op("*").Id("RequestPermissionOutcome")).Id("UnmarshalJSON").Params(Id("b").Index().Byte()).Error().Block( - Var().Id("m").Map(String()).Qual("encoding/json", "RawMessage"), - If(List(Id("err")).Op(":=").Qual("encoding/json", "Unmarshal").Call(Id("b"), Op("&").Id("m")), Id("err").Op("!=").Nil()).Block(Return(Id("err"))), - Var().Id("outcome").String(), - If(List(Id("v"), Id("ok")).Op(":=").Id("m").Index(Lit("outcome")), Id("ok")).Block( - Qual("encoding/json", "Unmarshal").Call(Id("v"), Op("&").Id("outcome")), - ), - Switch(Id("outcome")).Block( - Case(Lit("cancelled")).Block( - Id("o").Dot("Cancelled").Op("=").Op("&").Id("RequestPermissionOutcomeCancelled").Values(), - Return(Nil()), - ), - Case(Lit("selected")).Block( - Var().Id("v2").Id("RequestPermissionOutcomeSelected"), - If(List(Id("err")).Op(":=").Qual("encoding/json", "Unmarshal").Call(Id("b"), Op("&").Id("v2")), Id("err").Op("!=").Nil()).Block(Return(Id("err"))), - Id("o").Dot("Selected").Op("=").Op("&").Id("v2"), - Return(Nil()), - ), - ), - Return(Nil()), - ) - // MarshalJSON - f.Func().Params(Id("o").Id("RequestPermissionOutcome")).Id("MarshalJSON").Params().Params(Index().Byte(), Error()).Block( - If(Id("o").Dot("Cancelled").Op("!=").Nil()).Block( - Return(Qual("encoding/json", "Marshal").Call(Map(String()).Any().Values(Dict{Lit("outcome"): Lit("cancelled")}))), - ), - If(Id("o").Dot("Selected").Op("!=").Nil()).Block( - Return(Qual("encoding/json", "Marshal").Call(Map(String()).Any().Values(Dict{ - Lit("outcome"): Lit("selected"), - Lit("optionId"): Id("o").Dot("Selected").Dot("OptionId"), - }))), - ), - Return(Index().Byte().Values(), Nil()), - ) - f.Line() -} - -func emitSessionUpdateJen(f *File) { - // Variant types - f.Type().Id("SessionUpdateUserMessageChunk").Struct( - Id("Content").Id("ContentBlock").Tag(map[string]string{"json": "content"}), - ) - f.Type().Id("SessionUpdateAgentMessageChunk").Struct( - Id("Content").Id("ContentBlock").Tag(map[string]string{"json": "content"}), - ) - f.Type().Id("SessionUpdateAgentThoughtChunk").Struct( - Id("Content").Id("ContentBlock").Tag(map[string]string{"json": "content"}), - ) - f.Type().Id("SessionUpdateToolCall").Struct( - Id("Content").Index().Id("ToolCallContent").Tag(map[string]string{"json": "content,omitempty"}), - Id("Kind").Id("ToolKind").Tag(map[string]string{"json": "kind,omitempty"}), - Id("Locations").Index().Id("ToolCallLocation").Tag(map[string]string{"json": "locations,omitempty"}), - Id("RawInput").Any().Tag(map[string]string{"json": "rawInput,omitempty"}), - Id("RawOutput").Any().Tag(map[string]string{"json": "rawOutput,omitempty"}), - Id("Status").Id("ToolCallStatus").Tag(map[string]string{"json": "status,omitempty"}), - Id("Title").String().Tag(map[string]string{"json": "title"}), - Id("ToolCallId").Id("ToolCallId").Tag(map[string]string{"json": "toolCallId"}), - ) - f.Type().Id("SessionUpdateToolCallUpdate").Struct( - Id("Content").Index().Id("ToolCallContent").Tag(map[string]string{"json": "content,omitempty"}), - Id("Kind").Any().Tag(map[string]string{"json": "kind,omitempty"}), - Id("Locations").Index().Id("ToolCallLocation").Tag(map[string]string{"json": "locations,omitempty"}), - Id("RawInput").Any().Tag(map[string]string{"json": "rawInput,omitempty"}), - Id("RawOutput").Any().Tag(map[string]string{"json": "rawOutput,omitempty"}), - Id("Status").Any().Tag(map[string]string{"json": "status,omitempty"}), - Id("Title").Op("*").String().Tag(map[string]string{"json": "title,omitempty"}), - Id("ToolCallId").Id("ToolCallId").Tag(map[string]string{"json": "toolCallId"}), - ) - f.Type().Id("SessionUpdatePlan").Struct( - Id("Entries").Index().Id("PlanEntry").Tag(map[string]string{"json": "entries"}), - ) - f.Line() - // Holder - f.Type().Id("SessionUpdate").Struct( - Id("UserMessageChunk").Op("*").Id("SessionUpdateUserMessageChunk").Tag(map[string]string{"json": "-"}), - Id("AgentMessageChunk").Op("*").Id("SessionUpdateAgentMessageChunk").Tag(map[string]string{"json": "-"}), - Id("AgentThoughtChunk").Op("*").Id("SessionUpdateAgentThoughtChunk").Tag(map[string]string{"json": "-"}), - Id("ToolCall").Op("*").Id("SessionUpdateToolCall").Tag(map[string]string{"json": "-"}), - Id("ToolCallUpdate").Op("*").Id("SessionUpdateToolCallUpdate").Tag(map[string]string{"json": "-"}), - Id("Plan").Op("*").Id("SessionUpdatePlan").Tag(map[string]string{"json": "-"}), - ) - f.Func().Params(Id("s").Op("*").Id("SessionUpdate")).Id("UnmarshalJSON").Params(Id("b").Index().Byte()).Error().Block( - Var().Id("m").Map(String()).Qual("encoding/json", "RawMessage"), - If(List(Id("err")).Op(":=").Qual("encoding/json", "Unmarshal").Call(Id("b"), Op("&").Id("m")), Id("err").Op("!=").Nil()).Block(Return(Id("err"))), - Var().Id("kind").String(), - If(List(Id("v"), Id("ok")).Op(":=").Id("m").Index(Lit("sessionUpdate")), Id("ok")).Block( - Qual("encoding/json", "Unmarshal").Call(Id("v"), Op("&").Id("kind")), - ), - Switch(Id("kind")).Block( - Case(Lit("user_message_chunk")).Block( - Var().Id("v").Id("SessionUpdateUserMessageChunk"), - If(List(Id("err")).Op(":=").Qual("encoding/json", "Unmarshal").Call(Id("b"), Op("&").Id("v")), Id("err").Op("!=").Nil()).Block(Return(Id("err"))), - Id("s").Dot("UserMessageChunk").Op("=").Op("&").Id("v"), - Return(Nil()), - ), - Case(Lit("agent_message_chunk")).Block( - Var().Id("v").Id("SessionUpdateAgentMessageChunk"), - If(List(Id("err")).Op(":=").Qual("encoding/json", "Unmarshal").Call(Id("b"), Op("&").Id("v")), Id("err").Op("!=").Nil()).Block(Return(Id("err"))), - Id("s").Dot("AgentMessageChunk").Op("=").Op("&").Id("v"), - Return(Nil()), - ), - Case(Lit("agent_thought_chunk")).Block( - Var().Id("v").Id("SessionUpdateAgentThoughtChunk"), - If(List(Id("err")).Op(":=").Qual("encoding/json", "Unmarshal").Call(Id("b"), Op("&").Id("v")), Id("err").Op("!=").Nil()).Block(Return(Id("err"))), - Id("s").Dot("AgentThoughtChunk").Op("=").Op("&").Id("v"), - Return(Nil()), - ), - Case(Lit("tool_call")).Block( - Var().Id("v").Id("SessionUpdateToolCall"), - If(List(Id("err")).Op(":=").Qual("encoding/json", "Unmarshal").Call(Id("b"), Op("&").Id("v")), Id("err").Op("!=").Nil()).Block(Return(Id("err"))), - Id("s").Dot("ToolCall").Op("=").Op("&").Id("v"), - Return(Nil()), - ), - Case(Lit("tool_call_update")).Block( - Var().Id("v").Id("SessionUpdateToolCallUpdate"), - If(List(Id("err")).Op(":=").Qual("encoding/json", "Unmarshal").Call(Id("b"), Op("&").Id("v")), Id("err").Op("!=").Nil()).Block(Return(Id("err"))), - Id("s").Dot("ToolCallUpdate").Op("=").Op("&").Id("v"), - Return(Nil()), - ), - Case(Lit("plan")).Block( - Var().Id("v").Id("SessionUpdatePlan"), - If(List(Id("err")).Op(":=").Qual("encoding/json", "Unmarshal").Call(Id("b"), Op("&").Id("v")), Id("err").Op("!=").Nil()).Block(Return(Id("err"))), - Id("s").Dot("Plan").Op("=").Op("&").Id("v"), - Return(Nil()), - ), - ), - Return(Nil()), - ) - // MarshalJSON - f.Func().Params(Id("s").Id("SessionUpdate")).Id("MarshalJSON").Params().Params(Index().Byte(), Error()).Block( - If(Id("s").Dot("UserMessageChunk").Op("!=").Nil()).Block( - Return(Qual("encoding/json", "Marshal").Call(Map(String()).Any().Values(Dict{ - Lit("sessionUpdate"): Lit("user_message_chunk"), - Lit("content"): Id("s").Dot("UserMessageChunk").Dot("Content"), - }))), - ), - If(Id("s").Dot("AgentMessageChunk").Op("!=").Nil()).Block( - Return(Qual("encoding/json", "Marshal").Call(Map(String()).Any().Values(Dict{ - Lit("sessionUpdate"): Lit("agent_message_chunk"), - Lit("content"): Id("s").Dot("AgentMessageChunk").Dot("Content"), - }))), - ), - If(Id("s").Dot("AgentThoughtChunk").Op("!=").Nil()).Block( - Return(Qual("encoding/json", "Marshal").Call(Map(String()).Any().Values(Dict{ - Lit("sessionUpdate"): Lit("agent_thought_chunk"), - Lit("content"): Id("s").Dot("AgentThoughtChunk").Dot("Content"), - }))), - ), - If(Id("s").Dot("ToolCall").Op("!=").Nil()).Block( - Return(Qual("encoding/json", "Marshal").Call(Map(String()).Any().Values(Dict{ - Lit("sessionUpdate"): Lit("tool_call"), - Lit("content"): Id("s").Dot("ToolCall").Dot("Content"), - Lit("kind"): Id("s").Dot("ToolCall").Dot("Kind"), - Lit("locations"): Id("s").Dot("ToolCall").Dot("Locations"), - Lit("rawInput"): Id("s").Dot("ToolCall").Dot("RawInput"), - Lit("rawOutput"): Id("s").Dot("ToolCall").Dot("RawOutput"), - Lit("status"): Id("s").Dot("ToolCall").Dot("Status"), - Lit("title"): Id("s").Dot("ToolCall").Dot("Title"), - Lit("toolCallId"): Id("s").Dot("ToolCall").Dot("ToolCallId"), - }))), - ), - If(Id("s").Dot("ToolCallUpdate").Op("!=").Nil()).Block( - Return(Qual("encoding/json", "Marshal").Call(Map(String()).Any().Values(Dict{ - Lit("sessionUpdate"): Lit("tool_call_update"), - Lit("content"): Id("s").Dot("ToolCallUpdate").Dot("Content"), - Lit("kind"): Id("s").Dot("ToolCallUpdate").Dot("Kind"), - Lit("locations"): Id("s").Dot("ToolCallUpdate").Dot("Locations"), - Lit("rawInput"): Id("s").Dot("ToolCallUpdate").Dot("RawInput"), - Lit("rawOutput"): Id("s").Dot("ToolCallUpdate").Dot("RawOutput"), - Lit("status"): Id("s").Dot("ToolCallUpdate").Dot("Status"), - Lit("title"): Id("s").Dot("ToolCallUpdate").Dot("Title"), - Lit("toolCallId"): Id("s").Dot("ToolCallUpdate").Dot("ToolCallId"), - }))), - ), - If(Id("s").Dot("Plan").Op("!=").Nil()).Block( - Return(Qual("encoding/json", "Marshal").Call(Map(String()).Any().Values(Dict{ - Lit("sessionUpdate"): Lit("plan"), - Lit("entries"): Id("s").Dot("Plan").Dot("Entries"), - }))), - ), - Return(Index().Byte().Values(), Nil()), - ) - f.Line() -} - -func primitiveJenType(t string) Code { - switch t { - case "string": - return String() - case "integer": - return Int() - case "number": - return Float64() - case "boolean": - return Bool() - default: - return Any() - } -} - -func jenTypeFor(d *Definition) Code { - if d == nil { - return Any() - } - if d.Ref != "" { - if strings.HasPrefix(d.Ref, "#/$defs/") { - return Id(d.Ref[len("#/$defs/"):]) - } - return Any() - } - if len(d.Enum) > 0 { - return String() - } - switch primaryType(d) { - case "string": - return String() - case "integer": - return Int() - case "number": - return Float64() - case "boolean": - return Bool() - case "array": - return Index().Add(jenTypeFor(d.Items)) - case "object": - if len(d.Properties) == 0 { - return Map(String()).Any() - } - return Map(String()).Any() - default: - if len(d.AnyOf) > 0 || len(d.OneOf) > 0 { - return Any() - } - return Any() - } -} - -// jenTypeForOptional maps unions that include null to pointer types where applicable. -func jenTypeForOptional(d *Definition) Code { - if d == nil { - return Any() - } - // Check anyOf/oneOf with exactly one non-null + null - list := d.AnyOf - if len(list) == 0 { - list = d.OneOf - } - if len(list) == 2 { - var nonNull *Definition - for _, e := range list { - if e == nil { - continue - } - if s, ok := e.Type.(string); ok && s == "null" { - continue - } - if e.Const != nil { - nn := *e - nn.Type = "string" - nonNull = &nn - } else { - nonNull = e - } - } - if nonNull != nil { - if nonNull.Ref != "" && strings.HasPrefix(nonNull.Ref, "#/$defs/") { - return Op("*").Id(nonNull.Ref[len("#/$defs/"):]) - } - switch primaryType(nonNull) { - case "string": - return Op("*").String() - case "integer": - return Op("*").Int() - case "number": - return Op("*").Float64() - case "boolean": - return Op("*").Bool() - } - } - } - return jenTypeFor(d) -} - -func isNullResponse(def *Definition) bool { - if def == nil { - return true - } - // type: null or oneOf with const null (unlikely here) - if s, ok := def.Type.(string); ok && s == "null" { - return true - } - return false -} - -func dispatchMethodNameForNotification(methodKey, typeName string) string { - switch methodKey { - case "session_update": - return "SessionUpdate" - case "session_cancel": - return "Cancel" - default: - // Fallback to type base without suffix - if strings.HasSuffix(typeName, "Notification") { - return strings.TrimSuffix(typeName, "Notification") - } - return typeName - } -} - -func writeDispatchJen(outDir string, schema *Schema, meta *Meta) error { - // Build method groups using schema + meta inference - groups := buildMethodGroups(schema, meta) - - // Agent handler method - fAgent := NewFile("acp") - fAgent.HeaderComment("Code generated by acp-go-generator; DO NOT EDIT.") - // func (a *AgentSideConnection) handle(method string, params json.RawMessage) (any, *RequestError) { switch method { ... } } - switchCases := []Code{} - // deterministic order via meta.AgentMethods - amKeys := make([]string, 0, len(meta.AgentMethods)) - for k := range meta.AgentMethods { - amKeys = append(amKeys, k) - } - sort.Strings(amKeys) - for _, k := range amKeys { - wire := meta.AgentMethods[k] - mi := groups["agent|"+wire] - if mi == nil { - continue - } - caseBody := []Code{} - if mi.Notif != "" { - // var p T; if err := json.Unmarshal(params, &p); err != nil { return nil, NewInvalidParams(...) } - caseBody = append(caseBody, - Var().Id("p").Id(mi.Notif), - If( - List(Id("err")).Op(":=").Qual("encoding/json", "Unmarshal").Call(Id("params"), Op("&").Id("p")), - Id("err").Op("!=").Nil(), - ).Block( - Return(Nil(), Id("NewInvalidParams").Call(Map(String()).Id("any").Values(Dict{Lit("error"): Id("err").Dot("Error").Call()}))), - ), - // Validate if available - If(List(Id("err")).Op(":=").Id("p").Dot("Validate").Call(), Id("err").Op("!=").Nil()).Block( - Return(Nil(), Id("NewInvalidParams").Call(Map(String()).Id("any").Values(Dict{Lit("error"): Id("err").Dot("Error").Call()}))), - ), - ) - // if err := a.agent.Call(p); err != nil { return nil, toReqErr(err) }; return nil, nil - callName := dispatchMethodNameForNotification(k, mi.Notif) - caseBody = append(caseBody, - If( - List(Id("err")).Op(":=").Id("a").Dot("agent").Dot(callName).Call(Id("p")), - Id("err").Op("!=").Nil(), - ).Block( - Return(Nil(), Id("toReqErr").Call(Id("err"))), - ), - Return(Nil(), Nil()), - ) - } else if mi.Req != "" { - respName := strings.TrimSuffix(mi.Req, "Request") + "Response" - caseBody = append(caseBody, - Var().Id("p").Id(mi.Req), - If( - List(Id("err")).Op(":=").Qual("encoding/json", "Unmarshal").Call(Id("params"), Op("&").Id("p")), - Id("err").Op("!=").Nil(), - ).Block( - Return(Nil(), Id("NewInvalidParams").Call(Map(String()).Id("any").Values(Dict{Lit("error"): Id("err").Dot("Error").Call()}))), - ), - If(List(Id("err")).Op(":=").Id("p").Dot("Validate").Call(), Id("err").Op("!=").Nil()).Block( - Return(Nil(), Id("NewInvalidParams").Call(Map(String()).Id("any").Values(Dict{Lit("error"): Id("err").Dot("Error").Call()}))), - ), - ) - methodName := strings.TrimSuffix(mi.Req, "Request") - // Optional: session/load lives on AgentLoader - if wire == "session/load" { - // Perform type assertion first, then branch - caseBody = append(caseBody, - List(Id("loader"), Id("ok")).Op(":=").Id("a").Dot("agent").Assert(Id("AgentLoader")), - If(Op("!").Id("ok")).Block( - Return(Nil(), Id("NewMethodNotFound").Call(Id("method"))), - ), - ) - if isNullResponse(schema.Defs[respName]) { - caseBody = append(caseBody, - If( - List(Id("err")).Op(":=").Id("loader").Dot(methodName).Call(Id("p")), - Id("err").Op("!=").Nil(), - ).Block( - Return(Nil(), Id("toReqErr").Call(Id("err"))), - ), - Return(Nil(), Nil()), - ) - } else { - caseBody = append(caseBody, - List(Id("resp"), Id("err")).Op(":=").Id("loader").Dot(methodName).Call(Id("p")), - If(Id("err").Op("!=").Nil()).Block(Return(Nil(), Id("toReqErr").Call(Id("err")))), - Return(Id("resp"), Nil()), - ) - } - } else if isDocsIgnoredMethod(schema, mi) { - // Undocumented/experimental agent methods require AgentExperimental - caseBody = append(caseBody, - List(Id("exp"), Id("ok")).Op(":=").Id("a").Dot("agent").Assert(Id("AgentExperimental")), - If(Op("!").Id("ok")).Block( - Return(Nil(), Id("NewMethodNotFound").Call(Id("method"))), - ), - ) - if isNullResponse(schema.Defs[respName]) { - caseBody = append(caseBody, - If( - List(Id("err")).Op(":=").Id("exp").Dot(methodName).Call(Id("p")), - Id("err").Op("!=").Nil(), - ).Block( - Return(Nil(), Id("toReqErr").Call(Id("err"))), - ), - Return(Nil(), Nil()), - ) - } else { - caseBody = append(caseBody, - List(Id("resp"), Id("err")).Op(":=").Id("exp").Dot(methodName).Call(Id("p")), - If(Id("err").Op("!=").Nil()).Block(Return(Nil(), Id("toReqErr").Call(Id("err")))), - Return(Id("resp"), Nil()), - ) - } - } else { - if isNullResponse(schema.Defs[respName]) { - caseBody = append(caseBody, - If( - List(Id("err")).Op(":=").Id("a").Dot("agent").Dot(methodName).Call(Id("p")), - Id("err").Op("!=").Nil(), - ).Block( - Return(Nil(), Id("toReqErr").Call(Id("err"))), - ), - Return(Nil(), Nil()), - ) - } else { - caseBody = append(caseBody, - List(Id("resp"), Id("err")).Op(":=").Id("a").Dot("agent").Dot(methodName).Call(Id("p")), - If(Id("err").Op("!=").Nil()).Block(Return(Nil(), Id("toReqErr").Call(Id("err")))), - Return(Id("resp"), Nil()), - ) - } - } - } - if len(caseBody) > 0 { - switchCases = append(switchCases, Case(Id("AgentMethod"+toExportedConst(k))).Block(caseBody...)) - } - } - switchCases = append(switchCases, Default().Block(Return(Nil(), Id("NewMethodNotFound").Call(Id("method"))))) - fAgent.Func().Params(Id("a").Op("*").Id("AgentSideConnection")).Id("handle").Params( - Id("method").String(), - Id("params").Qual("encoding/json", "RawMessage"), - ).Params(Any(), Op("*").Id("RequestError")).Block( - Switch(Id("method")).Block(switchCases...), - ) - // After generating the handler, also append outbound wrappers for AgentSideConnection - // Build const name reverse lookup - agentConst := map[string]string{} - for k, v := range meta.AgentMethods { - agentConst[v] = "AgentMethod" + toExportedConst(k) - } - clientConst := map[string]string{} - for k, v := range meta.ClientMethods { - clientConst[v] = "ClientMethod" + toExportedConst(k) - } - // Agent outbound: methods the agent can call on the client (stable order) - cmKeys2 := make([]string, 0, len(meta.ClientMethods)) - for k := range meta.ClientMethods { - cmKeys2 = append(cmKeys2, k) - } - sort.Strings(cmKeys2) - for _, k := range cmKeys2 { - wire := meta.ClientMethods[k] - mi := groups["client|"+wire] - if mi == nil { - continue - } - constName := clientConst[mi.Method] - if constName == "" { - continue - } - if mi.Notif != "" { - name := strings.TrimSuffix(mi.Notif, "Notification") - switch mi.Method { - case "session/update": - name = "SessionUpdate" - case "session/cancel": - name = "Cancel" - } - fAgent.Func().Params(Id("c").Op("*").Id("AgentSideConnection")).Id(name).Params(Id("params").Id(mi.Notif)).Error(). - Block(Return(Id("c").Dot("conn").Dot("SendNotification").Call(Id(constName), Id("params")))) - } else if mi.Req != "" { - respName := strings.TrimSuffix(mi.Req, "Request") + "Response" - if isNullResponse(schema.Defs[respName]) { - fAgent.Func().Params(Id("c").Op("*").Id("AgentSideConnection")).Id(strings.TrimSuffix(mi.Req, "Request")). - Params(Id("params").Id(mi.Req)).Error(). - Block(Return(Id("c").Dot("conn").Dot("SendRequestNoResult").Call(Id(constName), Id("params")))) - } else { - fAgent.Func().Params(Id("c").Op("*").Id("AgentSideConnection")).Id(strings.TrimSuffix(mi.Req, "Request")). - Params(Id("params").Id(mi.Req)).Params(Id(respName), Error()). - Block( - List(Id("resp"), Id("err")).Op(":=").Id("SendRequest").Types(Id(respName)).Call(Id("c").Dot("conn"), Id(constName), Id("params")), - Return(Id("resp"), Id("err")), - ) - } - } - } - var bufA bytes.Buffer - if err := fAgent.Render(&bufA); err != nil { - return err - } - if err := os.WriteFile(filepath.Join(outDir, "agent_gen.go"), bufA.Bytes(), 0o644); err != nil { - return err - } - - // Client handler method - fClient := NewFile("acp") - fClient.HeaderComment("Code generated by acp-go-generator; DO NOT EDIT.") - cCases := []Code{} - cmKeys := make([]string, 0, len(meta.ClientMethods)) - for k := range meta.ClientMethods { - cmKeys = append(cmKeys, k) - } - sort.Strings(cmKeys) - for _, k := range cmKeys { - wire := meta.ClientMethods[k] - mi := groups["client|"+wire] - if mi == nil { - continue - } - body := []Code{} - if mi.Notif != "" { - body = append(body, - Var().Id("p").Id(mi.Notif), - If(List(Id("err")).Op(":=").Qual("encoding/json", "Unmarshal").Call(Id("params"), Op("&").Id("p")), Id("err").Op("!=").Nil()).Block( - Return(Nil(), Id("NewInvalidParams").Call(Map(String()).Id("any").Values(Dict{Lit("error"): Id("err").Dot("Error").Call()}))), - ), - If(List(Id("err")).Op(":=").Id("p").Dot("Validate").Call(), Id("err").Op("!=").Nil()).Block( - Return(Nil(), Id("NewInvalidParams").Call(Map(String()).Id("any").Values(Dict{Lit("error"): Id("err").Dot("Error").Call()}))), - ), - ) - callName := dispatchMethodNameForNotification(k, mi.Notif) - body = append(body, - If(List(Id("err")).Op(":=").Id("c").Dot("client").Dot(callName).Call(Id("p")), Id("err").Op("!=").Nil()).Block( - Return(Nil(), Id("toReqErr").Call(Id("err"))), - ), - Return(Nil(), Nil()), - ) - } else if mi.Req != "" { - respName := strings.TrimSuffix(mi.Req, "Request") + "Response" - body = append(body, - Var().Id("p").Id(mi.Req), - If(List(Id("err")).Op(":=").Qual("encoding/json", "Unmarshal").Call(Id("params"), Op("&").Id("p")), Id("err").Op("!=").Nil()).Block( - Return(Nil(), Id("NewInvalidParams").Call(Map(String()).Id("any").Values(Dict{Lit("error"): Id("err").Dot("Error").Call()}))), - ), - If(List(Id("err")).Op(":=").Id("p").Dot("Validate").Call(), Id("err").Op("!=").Nil()).Block( - Return(Nil(), Id("NewInvalidParams").Call(Map(String()).Id("any").Values(Dict{Lit("error"): Id("err").Dot("Error").Call()}))), - ), - ) - methodName := strings.TrimSuffix(mi.Req, "Request") - // Optional/experimental undocumented methods: require ClientTerminal for terminal/*, ClientExperimental otherwise - if isDocsIgnoredMethod(schema, mi) { - clientIface := "ClientExperimental" - if strings.HasPrefix(wire, "terminal/") { - clientIface = "ClientTerminal" - } - // Perform type assertion first, then branch - body = append(body, - List(Id("t"), Id("ok")).Op(":=").Id("c").Dot("client").Assert(Id(clientIface)), - If(Op("!").Id("ok")).Block( - Return(Nil(), Id("NewMethodNotFound").Call(Id("method"))), - ), - ) - if isNullResponse(schema.Defs[respName]) { - body = append(body, - If(List(Id("err")).Op(":=").Id("t").Dot(methodName).Call(Id("p")), Id("err").Op("!=").Nil()).Block( - Return(Nil(), Id("toReqErr").Call(Id("err"))), - ), - Return(Nil(), Nil()), - ) - } else { - body = append(body, - List(Id("resp"), Id("err")).Op(":=").Id("t").Dot(methodName).Call(Id("p")), - If(Id("err").Op("!=").Nil()).Block(Return(Nil(), Id("toReqErr").Call(Id("err")))), - Return(Id("resp"), Nil()), - ) - } - } else { - if isNullResponse(schema.Defs[respName]) { - body = append(body, - If(List(Id("err")).Op(":=").Id("c").Dot("client").Dot(methodName).Call(Id("p")), Id("err").Op("!=").Nil()).Block( - Return(Nil(), Id("toReqErr").Call(Id("err"))), - ), - Return(Nil(), Nil()), - ) - } else { - body = append(body, - List(Id("resp"), Id("err")).Op(":=").Id("c").Dot("client").Dot(methodName).Call(Id("p")), - If(Id("err").Op("!=").Nil()).Block(Return(Nil(), Id("toReqErr").Call(Id("err")))), - Return(Id("resp"), Nil()), - ) - } - } - } - if len(body) > 0 { - cCases = append(cCases, Case(Id("ClientMethod"+toExportedConst(k))).Block(body...)) - } - } - cCases = append(cCases, Default().Block(Return(Nil(), Id("NewMethodNotFound").Call(Id("method"))))) - fClient.Func().Params(Id("c").Op("*").Id("ClientSideConnection")).Id("handle").Params( - Id("method").String(), Id("params").Qual("encoding/json", "RawMessage")).Params( - Any(), Op("*").Id("RequestError")).Block( - Switch(Id("method")).Block(cCases...), - ) - // After generating the handler, also append outbound wrappers for ClientSideConnection - // Client outbound: methods the client can call on the agent (stable order) - amKeys2 := make([]string, 0, len(meta.AgentMethods)) - for k := range meta.AgentMethods { - amKeys2 = append(amKeys2, k) - } - sort.Strings(amKeys2) - for _, k := range amKeys2 { - wire := meta.AgentMethods[k] - mi := groups["agent|"+wire] - if mi == nil { - continue - } - constName := agentConst[mi.Method] - if constName == "" { - continue - } - if mi.Notif != "" { - name := strings.TrimSuffix(mi.Notif, "Notification") - switch mi.Method { - case "session/update": - name = "SessionUpdate" - case "session/cancel": - name = "Cancel" - } - fClient.Func().Params(Id("c").Op("*").Id("ClientSideConnection")).Id(name).Params(Id("params").Id(mi.Notif)).Error(). - Block(Return(Id("c").Dot("conn").Dot("SendNotification").Call(Id(constName), Id("params")))) - } else if mi.Req != "" { - respName := strings.TrimSuffix(mi.Req, "Request") + "Response" - if isNullResponse(schema.Defs[respName]) { - fClient.Func().Params(Id("c").Op("*").Id("ClientSideConnection")).Id(strings.TrimSuffix(mi.Req, "Request")). - Params(Id("params").Id(mi.Req)).Error(). - Block(Return(Id("c").Dot("conn").Dot("SendRequestNoResult").Call(Id(constName), Id("params")))) - } else { - fClient.Func().Params(Id("c").Op("*").Id("ClientSideConnection")).Id(strings.TrimSuffix(mi.Req, "Request")). - Params(Id("params").Id(mi.Req)).Params(Id(respName), Error()). - Block( - List(Id("resp"), Id("err")).Op(":=").Id("SendRequest").Types(Id(respName)).Call(Id("c").Dot("conn"), Id(constName), Id("params")), - Return(Id("resp"), Id("err")), - ) - } - } - } - var bufC bytes.Buffer - if err := fClient.Render(&bufC); err != nil { - return err - } - if err := os.WriteFile(filepath.Join(outDir, "client_gen.go"), bufC.Bytes(), 0o644); err != nil { - return err - } - - // Clean up old split outbound files if present - _ = os.Remove(filepath.Join(outDir, "agent_outbound_gen.go")) - _ = os.Remove(filepath.Join(outDir, "client_outbound_gen.go")) - return nil -} - -func sanitizeComment(s string) string { - // Remove backticks and normalize newlines - s = strings.ReplaceAll(s, "`", "'") - lines := strings.Split(s, "\n") - for i := range lines { - lines[i] = strings.TrimSpace(lines[i]) - } - return strings.Join(lines, " ") -} - -func primaryType(d *Definition) string { - if d == nil || d.Type == nil { - return "" - } - switch v := d.Type.(type) { - case string: - return v - case []any: - // choose a non-null type if present - var first string - for _, e := range v { - if s, ok := e.(string); ok { - if first == "" { - first = s - } - if s != "null" { - return s - } - } - } - return first - default: - return "" - } -} - -func toExportedField(name string) string { - // Convert camelCase or snake_case to PascalCase; keep common acronyms minimal (ID -> Id) - // First, split on underscores - parts := strings.Split(name, "_") - if len(parts) == 1 { - // handle camelCase - parts = splitCamel(name) - } - for i := range parts { - parts[i] = titleWord(parts[i]) - } - return strings.Join(parts, "") -} - -func splitCamel(s string) []string { - var parts []string - last := 0 - for i := 1; i < len(s); i++ { - if isBoundary(s[i-1], s[i]) { - parts = append(parts, s[last:i]) - last = i - } - } - parts = append(parts, s[last:]) - return parts -} - -func isBoundary(prev, curr byte) bool { - return (prev >= 'a' && prev <= 'z' && curr >= 'A' && curr <= 'Z') || curr == '_' -} - -func toEnumConst(typeName, val string) string { - // Build CONST like - // Normalize value: replace non-alnum with underscores, split by underscores or spaces, title-case. - cleaned := make([]rune, 0, len(val)) - for _, r := range val { - if (r >= 'a' && r <= 'z') || (r >= 'A' && r <= 'Z') || (r >= '0' && r <= '9') { - cleaned = append(cleaned, r) - } else { - cleaned = append(cleaned, '_') - } - } - parts := strings.FieldsFunc(string(cleaned), func(r rune) bool { return r == '_' }) - for i := range parts { - parts[i] = titleWord(strings.ToLower(parts[i])) - } - return typeName + strings.Join(parts, "") -} - -func titleWord(s string) string { - if s == "" { - return s - } - r := []rune(s) - r[0] = unicode.ToUpper(r[0]) - for i := 1; i < len(r); i++ { - r[i] = unicode.ToLower(r[i]) - } - return string(r) -} diff --git a/go/constants.go b/go/constants_gen.go similarity index 100% rename from go/constants.go rename to go/constants_gen.go diff --git a/go/types.go b/go/types_gen.go similarity index 100% rename from go/types.go rename to go/types_gen.go From becba9a3fc31b1459eddae4b0bb819b076c12b77 Mon Sep 17 00:00:00 2001 From: Thomas Kosiewski Date: Sun, 31 Aug 2025 19:03:08 +0200 Subject: [PATCH 07/22] feat: add constructor helpers and compact examples for Go ACP implementation Change-Id: I337b07eea029a16481cf41869f9a327f9184c5fa Signed-off-by: Thomas Kosiewski --- go/cmd/generate/internal/emit/helpers.go | 112 ++++++++++++++++++++++ go/cmd/generate/main.go | 5 + go/example/agent/main.go | 52 +++++----- go/example/claude-code/main.go | 2 +- go/example/client/main.go | 5 +- go/example/gemini/main.go | 2 +- go/example_agent_test.go | 94 ++++++++++++++++++ go/example_client_test.go | 117 +++++++++++++++++++++++ go/example_gemini_test.go | 61 ++++++++++++ go/helpers_gen.go | 91 ++++++++++++++++++ 10 files changed, 508 insertions(+), 33 deletions(-) create mode 100644 go/cmd/generate/internal/emit/helpers.go create mode 100644 go/example_agent_test.go create mode 100644 go/example_client_test.go create mode 100644 go/example_gemini_test.go create mode 100644 go/helpers_gen.go diff --git a/go/cmd/generate/internal/emit/helpers.go b/go/cmd/generate/internal/emit/helpers.go new file mode 100644 index 0000000..d1a3429 --- /dev/null +++ b/go/cmd/generate/internal/emit/helpers.go @@ -0,0 +1,112 @@ +package emit + +import ( + "bytes" + "os" + "path/filepath" + + "github.com/zed-industries/agent-client-protocol/go/cmd/generate/internal/load" +) + +// WriteHelpersJen emits go/helpers_gen.go with small constructor helpers +// for common union variants and a Ptr generic helper. +func WriteHelpersJen(outDir string, _ *load.Schema, _ *load.Meta) error { + f := NewFile("acp") + f.HeaderComment("Code generated by acp-go-generator; DO NOT EDIT.") + + // Content helpers + f.Comment("TextBlock constructs a text content block.") + f.Func().Id("TextBlock").Params(Id("text").String()).Id("ContentBlock").Block( + Return(Id("ContentBlock").Values(Dict{ + Id("Type"): Lit("text"), + Id("Text"): Op("&").Id("TextContent").Values(Dict{Id("Text"): Id("text")}), + })), + ) + f.Line() + + f.Comment("ImageBlock constructs an inline image content block with base64-encoded data.") + f.Func().Id("ImageBlock").Params(Id("data").String(), Id("mimeType").String()).Id("ContentBlock").Block( + Return(Id("ContentBlock").Values(Dict{ + Id("Type"): Lit("image"), + Id("Image"): Op("&").Id("ImageContent").Values(Dict{Id("Data"): Id("data"), Id("MimeType"): Id("mimeType")}), + })), + ) + f.Line() + + f.Comment("AudioBlock constructs an inline audio content block with base64-encoded data.") + f.Func().Id("AudioBlock").Params(Id("data").String(), Id("mimeType").String()).Id("ContentBlock").Block( + Return(Id("ContentBlock").Values(Dict{ + Id("Type"): Lit("audio"), + Id("Audio"): Op("&").Id("AudioContent").Values(Dict{Id("Data"): Id("data"), Id("MimeType"): Id("mimeType")}), + })), + ) + f.Line() + + f.Comment("ResourceLinkBlock constructs a resource_link content block with a name and URI.") + f.Func().Id("ResourceLinkBlock").Params(Id("name").String(), Id("uri").String()).Id("ContentBlock").Block( + Return(Id("ContentBlock").Values(Dict{ + Id("Type"): Lit("resource_link"), + Id("ResourceLink"): Op("&").Id("ResourceLinkContent").Values(Dict{Id("Name"): Id("name"), Id("Uri"): Id("uri")}), + })), + ) + f.Line() + + f.Comment("ResourceBlock wraps an embedded resource as a content block.") + f.Func().Id("ResourceBlock").Params(Id("res").Id("EmbeddedResource")).Id("ContentBlock").Block( + Var().Id("r").Id("EmbeddedResource").Op("=").Id("res"), + Return(Id("ContentBlock").Values(Dict{ + Id("Type"): Lit("resource"), + Id("Resource"): Op("&").Id("r"), + })), + ) + f.Line() + + // ToolCall content helpers + f.Comment("ToolContent wraps a content block as tool-call content.") + f.Func().Id("ToolContent").Params(Id("block").Id("ContentBlock")).Id("ToolCallContent").Block( + Var().Id("b").Id("ContentBlock").Op("=").Id("block"), + Return(Id("ToolCallContent").Values(Dict{ + Id("Type"): Lit("content"), + Id("Content"): Op("&").Id("b"), + })), + ) + f.Line() + + f.Comment("ToolDiffContent constructs a diff tool-call content. If oldText is omitted, the field is left empty.") + f.Func().Id("ToolDiffContent").Params(Id("path").String(), Id("newText").String(), Id("oldText").Op("...").String()).Id("ToolCallContent").Block( + Var().Id("o").Op("*").String(), + If(Id("len").Call(Id("oldText")).Op(">").Lit(0)).Block( + Id("o").Op("=").Op("&").Id("oldText").Index(Lit(0)), + ), + Return(Id("ToolCallContent").Values(Dict{ + Id("Type"): Lit("diff"), + Id("Diff"): Op("&").Id("DiffContent").Values(Dict{ + Id("Path"): Id("path"), + Id("NewText"): Id("newText"), + Id("OldText"): Id("o"), + }), + })), + ) + f.Line() + + f.Comment("ToolTerminalRef constructs a terminal reference tool-call content.") + f.Func().Id("ToolTerminalRef").Params(Id("terminalId").String()).Id("ToolCallContent").Block( + Return(Id("ToolCallContent").Values(Dict{ + Id("Type"): Lit("terminal"), + Id("Terminal"): Op("&").Id("TerminalRef").Values(Dict{Id("TerminalId"): Id("terminalId")}), + })), + ) + f.Line() + + // Generic pointer helper + f.Comment("Ptr returns a pointer to v.") + f.Func().Id("Ptr").Types(Id("T").Any()).Params(Id("v").Id("T")).Op("*").Id("T").Block( + Return(Op("&").Id("v")), + ) + + var buf bytes.Buffer + if err := f.Render(&buf); err != nil { + return err + } + return os.WriteFile(filepath.Join(outDir, "helpers_gen.go"), buf.Bytes(), 0o644) +} diff --git a/go/cmd/generate/main.go b/go/cmd/generate/main.go index 1b01b97..7d29118 100644 --- a/go/cmd/generate/main.go +++ b/go/cmd/generate/main.go @@ -50,6 +50,11 @@ func main() { if err := emit.WriteDispatchJen(outDir, schema, meta); err != nil { panic(err) } + + // Emit helpers after types so they can reference generated structs. + if err := emit.WriteHelpersJen(outDir, schema, meta); err != nil { + panic(err) + } } func findRepoRoot() string { diff --git a/go/example/agent/main.go b/go/example/agent/main.go index ab7df17..9650f5b 100644 --- a/go/example/agent/main.go +++ b/go/example/agent/main.go @@ -91,10 +91,9 @@ func (a *exampleAgent) simulateTurn(ctx context.Context, sid string) error { if err := a.conn.SessionUpdate(acp.SessionNotification{ SessionId: acp.SessionId(sid), Update: acp.SessionUpdate{ - AgentMessageChunk: &acp.SessionUpdateAgentMessageChunk{Content: acp.ContentBlock{ - Type: "text", - Text: &acp.TextContent{Text: "I'll help you with that. Let me start by reading some files to understand the current situation."}, - }}, + AgentMessageChunk: &acp.SessionUpdateAgentMessageChunk{ + Content: acp.TextBlock("I'll help you with that. Let me start by reading some files to understand the current situation."), + }, }, }); err != nil { return err @@ -127,10 +126,10 @@ func (a *exampleAgent) simulateTurn(ctx context.Context, sid string) error { Update: acp.SessionUpdate{ToolCallUpdate: &acp.SessionUpdateToolCallUpdate{ ToolCallId: acp.ToolCallId("call_1"), Status: "completed", - Content: []acp.ToolCallContent{{ - Type: "content", - Content: &acp.ContentBlock{Type: "text", Text: &acp.TextContent{Text: "# My Project\n\nThis is a sample project..."}}, - }}, + Content: []acp.ToolCallContent{ + acp.ToolContent( + acp.TextBlock("# My Project\n\nThis is a sample project...")), + }, RawOutput: map[string]any{"content": "# My Project\n\nThis is a sample project..."}, }}, }); err != nil { @@ -143,10 +142,11 @@ func (a *exampleAgent) simulateTurn(ctx context.Context, sid string) error { // more text if err := a.conn.SessionUpdate(acp.SessionNotification{ SessionId: acp.SessionId(sid), - Update: acp.SessionUpdate{AgentMessageChunk: &acp.SessionUpdateAgentMessageChunk{Content: acp.ContentBlock{ - Type: "text", - Text: &acp.TextContent{Text: " Now I understand the project structure. I need to make some changes to improve it."}, - }}}, + Update: acp.SessionUpdate{ + AgentMessageChunk: &acp.SessionUpdateAgentMessageChunk{ + Content: acp.TextBlock(" Now I understand the project structure. I need to make some changes to improve it."), + }, + }, }); err != nil { return err } @@ -175,8 +175,8 @@ func (a *exampleAgent) simulateTurn(ctx context.Context, sid string) error { ToolCall: acp.ToolCallUpdate{ ToolCallId: acp.ToolCallId("call_2"), Title: "Modifying critical configuration file", - Kind: ptr(acp.ToolKindEdit), - Status: ptr(acp.ToolCallStatusPending), + Kind: acp.Ptr(acp.ToolKindEdit), + Status: acp.Ptr(acp.ToolCallStatusPending), Locations: []acp.ToolCallLocation{{Path: "/home/user/project/config.json"}}, RawInput: map[string]any{"path": "/home/user/project/config.json", "content": "{\"database\": {\"host\": \"new-host\"}}"}, }, @@ -202,7 +202,7 @@ func (a *exampleAgent) simulateTurn(ctx context.Context, sid string) error { SessionId: acp.SessionId(sid), Update: acp.SessionUpdate{ToolCallUpdate: &acp.SessionUpdateToolCallUpdate{ ToolCallId: acp.ToolCallId("call_2"), - Status: ptr(acp.ToolCallStatusCompleted), + Status: acp.Ptr(acp.ToolCallStatusCompleted), RawOutput: map[string]any{"success": true, "message": "Configuration updated"}, }}, }); err != nil { @@ -213,10 +213,11 @@ func (a *exampleAgent) simulateTurn(ctx context.Context, sid string) error { } if err := a.conn.SessionUpdate(acp.SessionNotification{ SessionId: acp.SessionId(sid), - Update: acp.SessionUpdate{AgentMessageChunk: &acp.SessionUpdateAgentMessageChunk{Content: acp.ContentBlock{ - Type: "text", - Text: &acp.TextContent{Text: " Perfect! I've successfully updated the configuration. The changes have been applied."}, - }}}, + Update: acp.SessionUpdate{ + AgentMessageChunk: &acp.SessionUpdateAgentMessageChunk{ + Content: acp.TextBlock(" Perfect! I've successfully updated the configuration. The changes have been applied."), + }, + }, }); err != nil { return err } @@ -226,10 +227,11 @@ func (a *exampleAgent) simulateTurn(ctx context.Context, sid string) error { } if err := a.conn.SessionUpdate(acp.SessionNotification{ SessionId: acp.SessionId(sid), - Update: acp.SessionUpdate{AgentMessageChunk: &acp.SessionUpdateAgentMessageChunk{Content: acp.ContentBlock{ - Type: "text", - Text: &acp.TextContent{Text: " I understand you prefer not to make that change. I'll skip the configuration update."}, - }}}, + Update: acp.SessionUpdate{ + AgentMessageChunk: &acp.SessionUpdateAgentMessageChunk{ + Content: acp.TextBlock(" I understand you prefer not to make that change. I'll skip the configuration update."), + }, + }, }); err != nil { return err } @@ -259,10 +261,6 @@ func pause(ctx context.Context, d time.Duration) error { } } -func ptr[T any](t T) *T { - return &t -} - func main() { // Wire up stdio: write to stdout, read from stdin ag := newExampleAgent() diff --git a/go/example/claude-code/main.go b/go/example/claude-code/main.go index 3f28947..a6ebdbe 100644 --- a/go/example/claude-code/main.go +++ b/go/example/claude-code/main.go @@ -238,7 +238,7 @@ func main() { // Send prompt and wait for completion while streaming updates are printed via SessionUpdate if _, err := conn.Prompt(acp.PromptRequest{ SessionId: newSess.SessionId, - Prompt: []acp.ContentBlock{{Type: "text", Text: &acp.TextContent{Text: line}}}, + Prompt: []acp.ContentBlock{acp.TextBlock(line)}, }); err != nil { // If it's a JSON-RPC RequestError, surface more detail for troubleshooting if re, ok := err.(*acp.RequestError); ok { diff --git a/go/example/client/main.go b/go/example/client/main.go index b8cb317..5ca0d40 100644 --- a/go/example/client/main.go +++ b/go/example/client/main.go @@ -217,10 +217,7 @@ func main() { // Send prompt if _, err := conn.Prompt(acp.PromptRequest{ SessionId: newSess.SessionId, - Prompt: []acp.ContentBlock{{ - Type: "text", - Text: &acp.TextContent{Text: "Hello, agent!"}, - }}, + Prompt: []acp.ContentBlock{acp.TextBlock("Hello, agent!")}, }); err != nil { if re, ok := err.(*acp.RequestError); ok { if b, mErr := json.MarshalIndent(re, "", " "); mErr == nil { diff --git a/go/example/gemini/main.go b/go/example/gemini/main.go index 982b04d..adbacb6 100644 --- a/go/example/gemini/main.go +++ b/go/example/gemini/main.go @@ -258,7 +258,7 @@ func main() { // Send prompt and wait for completion while streaming updates are printed via SessionUpdate if _, err := conn.Prompt(acp.PromptRequest{ SessionId: newSess.SessionId, - Prompt: []acp.ContentBlock{{Type: "text", Text: &acp.TextContent{Text: line}}}, + Prompt: []acp.ContentBlock{acp.TextBlock(line)}, }); err != nil { // If it's a JSON-RPC RequestError, surface more detail for troubleshooting if re, ok := err.(*acp.RequestError); ok { diff --git a/go/example_agent_test.go b/go/example_agent_test.go new file mode 100644 index 0000000..0f28bc9 --- /dev/null +++ b/go/example_agent_test.go @@ -0,0 +1,94 @@ +package acp + +import ( + "os" +) + +// agentExample mirrors the go/example/agent flow in a compact form. +// It streams a short message, demonstrates a tool call + permission, +// then ends the turn. +type agentExample struct{ conn *AgentSideConnection } + +func (a *agentExample) SetAgentConnection(c *AgentSideConnection) { a.conn = c } + +func (agentExample) Authenticate(AuthenticateRequest) error { return nil } +func (agentExample) Initialize(InitializeRequest) (InitializeResponse, error) { + return InitializeResponse{ + ProtocolVersion: ProtocolVersionNumber, + AgentCapabilities: AgentCapabilities{LoadSession: false}, + }, nil +} +func (agentExample) Cancel(CancelNotification) error { return nil } +func (agentExample) NewSession(NewSessionRequest) (NewSessionResponse, error) { + return NewSessionResponse{SessionId: SessionId("sess_demo")}, nil +} + +func (a *agentExample) Prompt(p PromptRequest) (PromptResponse, error) { + // Stream an initial agent message. + _ = a.conn.SessionUpdate(SessionNotification{ + SessionId: p.SessionId, + Update: SessionUpdate{ + AgentMessageChunk: &SessionUpdateAgentMessageChunk{ + Content: TextBlock("I'll help you with that."), + }, + }, + }) + + // Announce a tool call. + _ = a.conn.SessionUpdate(SessionNotification{ + SessionId: p.SessionId, + Update: SessionUpdate{ToolCall: &SessionUpdateToolCall{ + ToolCallId: ToolCallId("call_1"), + Title: "Modifying configuration", + Kind: ToolKindEdit, + Status: ToolCallStatusPending, + Locations: []ToolCallLocation{{Path: "/project/config.json"}}, + RawInput: map[string]any{"path": "/project/config.json"}, + }}, + }) + + // Ask the client for permission to proceed with the change. + resp, _ := a.conn.RequestPermission(RequestPermissionRequest{ + SessionId: p.SessionId, + ToolCall: ToolCallUpdate{ + ToolCallId: ToolCallId("call_1"), + Title: "Modifying configuration", + Kind: Ptr(ToolKindEdit), + Status: Ptr(ToolCallStatusPending), + Locations: []ToolCallLocation{{Path: "/project/config.json"}}, + RawInput: map[string]any{"path": "/project/config.json"}, + }, + Options: []PermissionOption{ + {Kind: PermissionOptionKindAllowOnce, Name: "Allow", OptionId: PermissionOptionId("allow")}, + {Kind: PermissionOptionKindRejectOnce, Name: "Reject", OptionId: PermissionOptionId("reject")}, + }, + }) + + if resp.Outcome.Selected != nil && string(resp.Outcome.Selected.OptionId) == "allow" { + // Mark tool call completed and stream a final message. + _ = a.conn.SessionUpdate(SessionNotification{ + SessionId: p.SessionId, + Update: SessionUpdate{ToolCallUpdate: &SessionUpdateToolCallUpdate{ + ToolCallId: ToolCallId("call_1"), + Status: ToolCallStatusCompleted, + RawOutput: map[string]any{"success": true}, + }}, + }) + _ = a.conn.SessionUpdate(SessionNotification{ + SessionId: p.SessionId, + Update: SessionUpdate{AgentMessageChunk: &SessionUpdateAgentMessageChunk{Content: TextBlock("Done.")}}, + }) + } + + return PromptResponse{StopReason: StopReasonEndTurn}, nil +} + +// Example_agent wires the Agent to stdio so an external client +// can connect via this process' stdin/stdout. +func Example_agent() { + ag := &agentExample{} + asc := NewAgentSideConnection(ag, os.Stdout, os.Stdin) + ag.SetAgentConnection(asc) + // In a real program, block until the peer disconnects: + // <-asc.Done() +} diff --git a/go/example_client_test.go b/go/example_client_test.go new file mode 100644 index 0000000..5fe56f5 --- /dev/null +++ b/go/example_client_test.go @@ -0,0 +1,117 @@ +package acp + +import ( + "fmt" + "os" + "os/exec" + "path/filepath" + "strings" +) + +// clientExample mirrors go/example/client in a compact form: prints +// streamed updates, handles simple file ops, and picks the first +// permission option. +type clientExample struct{} + +func (clientExample) RequestPermission(p RequestPermissionRequest) (RequestPermissionResponse, error) { + if len(p.Options) == 0 { + return RequestPermissionResponse{ + Outcome: RequestPermissionOutcome{ + Cancelled: &RequestPermissionOutcomeCancelled{}, + }, + }, nil + } + return RequestPermissionResponse{ + Outcome: RequestPermissionOutcome{ + Selected: &RequestPermissionOutcomeSelected{OptionId: p.Options[0].OptionId}, + }, + }, nil +} + +func (clientExample) SessionUpdate(n SessionNotification) error { + u := n.Update + switch { + case u.AgentMessageChunk != nil: + c := u.AgentMessageChunk.Content + if c.Type == "text" && c.Text != nil { + fmt.Print(c.Text.Text) + } else { + fmt.Println("[", c.Type, "]") + } + case u.ToolCall != nil: + fmt.Printf("\n[tool] %s (%s)\n", u.ToolCall.Title, u.ToolCall.Status) + case u.ToolCallUpdate != nil: + fmt.Printf("\n[tool] %s -> %v\n", u.ToolCallUpdate.ToolCallId, u.ToolCallUpdate.Status) + } + return nil +} + +func (clientExample) WriteTextFile(p WriteTextFileRequest) error { + if !filepath.IsAbs(p.Path) { + return fmt.Errorf("path must be absolute: %s", p.Path) + } + if dir := filepath.Dir(p.Path); dir != "" { + _ = os.MkdirAll(dir, 0o755) + } + return os.WriteFile(p.Path, []byte(p.Content), 0o644) +} + +func (clientExample) ReadTextFile(p ReadTextFileRequest) (ReadTextFileResponse, error) { + if !filepath.IsAbs(p.Path) { + return ReadTextFileResponse{}, fmt.Errorf("path must be absolute: %s", p.Path) + } + b, err := os.ReadFile(p.Path) + if err != nil { + return ReadTextFileResponse{}, err + } + content := string(b) + if p.Line > 0 || p.Limit > 0 { + lines := strings.Split(content, "\n") + start := 0 + if p.Line > 0 { + if p.Line-1 > 0 { + start = p.Line - 1 + } + if start > len(lines) { + start = len(lines) + } + } + end := len(lines) + if p.Limit > 0 && start+p.Limit < end { + end = start + p.Limit + } + content = strings.Join(lines[start:end], "\n") + } + return ReadTextFileResponse{Content: content}, nil +} + +// Example_client launches the Go agent example, negotiates protocol, +// opens a session, and sends a simple prompt. +func Example_client() { + cmd := exec.Command("go", "run", "./example/agent") + stdin, _ := cmd.StdinPipe() + stdout, _ := cmd.StdoutPipe() + _ = cmd.Start() + + conn := NewClientSideConnection(clientExample{}, stdin, stdout) + _, _ = conn.Initialize(InitializeRequest{ + ProtocolVersion: ProtocolVersionNumber, + ClientCapabilities: ClientCapabilities{ + Fs: FileSystemCapability{ + ReadTextFile: true, + WriteTextFile: true, + }, + Terminal: true, + }, + }) + sess, _ := conn.NewSession(NewSessionRequest{ + Cwd: "/", + McpServers: []McpServer{}, + }) + _, _ = conn.Prompt(PromptRequest{ + SessionId: sess.SessionId, + Prompt: []ContentBlock{TextBlock("Hello, agent!")}, + }) + + _ = cmd.Process.Kill() +} diff --git a/go/example_gemini_test.go b/go/example_gemini_test.go new file mode 100644 index 0000000..8487c3f --- /dev/null +++ b/go/example_gemini_test.go @@ -0,0 +1,61 @@ +package acp + +import ( + "fmt" + "os/exec" +) + +// geminiClient mirrors go/example/gemini in brief: prints text chunks and +// selects the first permission option. File ops are no-ops here. +type geminiClient struct{} + +func (geminiClient) RequestPermission(p RequestPermissionRequest) (RequestPermissionResponse, error) { + if len(p.Options) == 0 { + return RequestPermissionResponse{Outcome: RequestPermissionOutcome{Cancelled: &RequestPermissionOutcomeCancelled{}}}, nil + } + return RequestPermissionResponse{Outcome: RequestPermissionOutcome{Selected: &RequestPermissionOutcomeSelected{OptionId: p.Options[0].OptionId}}}, nil +} + +func (geminiClient) SessionUpdate(n SessionNotification) error { + if n.Update.AgentMessageChunk != nil { + c := n.Update.AgentMessageChunk.Content + if c.Type == "text" && c.Text != nil { + fmt.Print(c.Text.Text) + } + } + return nil +} + +func (geminiClient) ReadTextFile(ReadTextFileRequest) (ReadTextFileResponse, error) { + return ReadTextFileResponse{}, nil +} +func (geminiClient) WriteTextFile(WriteTextFileRequest) error { return nil } + +// Example_gemini connects to a Gemini CLI speaking ACP over stdio, +// then initializes, opens a session, and sends a prompt. +func Example_gemini() { + cmd := exec.Command("gemini", "--experimental-acp") + stdin, _ := cmd.StdinPipe() + stdout, _ := cmd.StdoutPipe() + _ = cmd.Start() + + conn := NewClientSideConnection(geminiClient{}, stdin, stdout) + _, _ = conn.Initialize(InitializeRequest{ + ProtocolVersion: ProtocolVersionNumber, + ClientCapabilities: ClientCapabilities{ + Fs: FileSystemCapability{ + ReadTextFile: true, + WriteTextFile: true, + }, + Terminal: true, + }, + }) + sess, _ := conn.NewSession(NewSessionRequest{ + Cwd: "/", + McpServers: []McpServer{}, + }) + _, _ = conn.Prompt(PromptRequest{ + SessionId: sess.SessionId, + Prompt: []ContentBlock{TextBlock("list files")}, + }) +} diff --git a/go/helpers_gen.go b/go/helpers_gen.go new file mode 100644 index 0000000..51d45f2 --- /dev/null +++ b/go/helpers_gen.go @@ -0,0 +1,91 @@ +// Code generated by acp-go-generator; DO NOT EDIT. + +package acp + +// TextBlock constructs a text content block. +func TextBlock(text string) ContentBlock { + return ContentBlock{ + Text: &TextContent{Text: text}, + Type: "text", + } +} + +// ImageBlock constructs an inline image content block with base64-encoded data. +func ImageBlock(data string, mimeType string) ContentBlock { + return ContentBlock{ + Image: &ImageContent{ + Data: data, + MimeType: mimeType, + }, + Type: "image", + } +} + +// AudioBlock constructs an inline audio content block with base64-encoded data. +func AudioBlock(data string, mimeType string) ContentBlock { + return ContentBlock{ + Audio: &AudioContent{ + Data: data, + MimeType: mimeType, + }, + Type: "audio", + } +} + +// ResourceLinkBlock constructs a resource_link content block with a name and URI. +func ResourceLinkBlock(name string, uri string) ContentBlock { + return ContentBlock{ + ResourceLink: &ResourceLinkContent{ + Name: name, + Uri: uri, + }, + Type: "resource_link", + } +} + +// ResourceBlock wraps an embedded resource as a content block. +func ResourceBlock(res EmbeddedResource) ContentBlock { + var r EmbeddedResource = res + return ContentBlock{ + Resource: &r, + Type: "resource", + } +} + +// ToolContent wraps a content block as tool-call content. +func ToolContent(block ContentBlock) ToolCallContent { + var b ContentBlock = block + return ToolCallContent{ + Content: &b, + Type: "content", + } +} + +// ToolDiffContent constructs a diff tool-call content. If oldText is omitted, the field is left empty. +func ToolDiffContent(path string, newText string, oldText ...string) ToolCallContent { + var o *string + if len(oldText) > 0 { + o = &oldText[0] + } + return ToolCallContent{ + Diff: &DiffContent{ + NewText: newText, + OldText: o, + Path: path, + }, + Type: "diff", + } +} + +// ToolTerminalRef constructs a terminal reference tool-call content. +func ToolTerminalRef(terminalId string) ToolCallContent { + return ToolCallContent{ + Terminal: &TerminalRef{TerminalId: terminalId}, + Type: "terminal", + } +} + +// Ptr returns a pointer to v. +func Ptr[T any](v T) *T { + return &v +} From 1b1b7dcde870e7198ce33ab47e8494adcd2cd228 Mon Sep 17 00:00:00 2001 From: Thomas Kosiewski Date: Sun, 31 Aug 2025 19:17:35 +0200 Subject: [PATCH 08/22] docs: add Go library README with installation and usage guide Change-Id: Idbc02da87f6925b2fd847ffe1a04fcdda33b2162 Signed-off-by: Thomas Kosiewski --- go/README.md | 74 ++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 74 insertions(+) create mode 100644 go/README.md diff --git a/go/README.md b/go/README.md new file mode 100644 index 0000000..c8e2d7e --- /dev/null +++ b/go/README.md @@ -0,0 +1,74 @@ + + Agent Client Protocol + + +# ACP Go Library + +The official Go implementation of the Agent Client Protocol (ACP) — a standardized communication protocol between code editors and AI‑powered coding agents. + +Learn more at + +## Installation + +```bash +go get github.com/zed-industries/agent-client-protocol/go@latest +``` + +## Get Started + +### Understand the Protocol + +Start by reading the [official ACP documentation](https://agentclientprotocol.com) to understand the core concepts and protocol specification. + +### Try the Examples + +The [examples directory](https://github.com/zed-industries/agent-client-protocol/tree/main/go/example) contains simple implementations of both Agents and Clients in Go. You can run them from your terminal or connect to external ACP agents. + +- Run the example Agent: + - `cd go && go run ./example/agent` +- Run the example Client (connects to the example Agent): + - `cd go && go run ./example/client` +- Connect to the Gemini CLI (ACP mode): + - `cd go && go run ./example/gemini -yolo` + - Optional flags: `-model`, `-sandbox`, `-debug`, `-gemini /path/to/gemini` +- Connect to Claude Code (via npx): + - `cd go && go run ./example/claude-code -yolo` + +### Explore the API + +Browse the Go package docs on pkg.go.dev for detailed API documentation: + +- + +If you're building an [Agent](https://agentclientprotocol.com/protocol/overview#agent): + +- Implement the `acp.Agent` interface (and optionally `acp.AgentLoader` for `session/load`). +- Create a connection with `acp.NewAgentSideConnection(agent, os.Stdout, os.Stdin)`. +- Send updates and make client requests using the returned connection. + +If you're building a [Client](https://agentclientprotocol.com/protocol/overview#client): + +- Implement the `acp.Client` interface (and optionally `acp.ClientTerminal` for terminal features). +- Launch or connect to your Agent process (stdio), then create a connection with `acp.NewClientSideConnection(client, stdin, stdout)`. +- Call `Initialize`, `NewSession`, and `Prompt` to run a turn and stream updates. + +Helper constructors are provided to reduce boilerplate when working with union types: + +- Content blocks: `acp.TextBlock`, `acp.ImageBlock`, `acp.AudioBlock`, `acp.ResourceLinkBlock`, `acp.ResourceBlock`. +- Tool content: `acp.ToolContent`, `acp.ToolDiffContent`, `acp.ToolTerminalRef`. +- Utility: `acp.Ptr[T]` for pointer fields in request/update structs. + +### Study a Production Implementation + +For a complete, production‑ready integration, see the [Gemini CLI Agent](https://github.com/google-gemini/gemini-cli) which exposes an ACP interface. The Go example client `go/example/gemini` demonstrates connecting to it via stdio. + +## Resources + +- [Go package docs](https://pkg.go.dev/github.com/zed-industries/agent-client-protocol/go) +- [Examples (Go)](https://github.com/zed-industries/agent-client-protocol/tree/main/go/example) +- [Protocol Documentation](https://agentclientprotocol.com) +- [GitHub Repository](https://github.com/zed-industries/agent-client-protocol) + +## Contributing + +See the main [repository](https://github.com/zed-industries/agent-client-protocol) for contribution guidelines. From 291245d8b33b200479c795113c76a0a4329580fc Mon Sep 17 00:00:00 2001 From: Thomas Kosiewski Date: Sun, 31 Aug 2025 19:36:11 +0200 Subject: [PATCH 09/22] docs: add Go library to documentation with package info and examples Change-Id: If82530aca4ae621a6d8200cebbdd50e144f57f24 Signed-off-by: Thomas Kosiewski --- README.md | 1 + docs/docs.json | 2 +- docs/libraries/go.mdx | 33 +++++++++++++++++++++++++++++++++ go/example/agent/main.go | 14 ++++++++++++++ 4 files changed, 49 insertions(+), 1 deletion(-) create mode 100644 docs/libraries/go.mdx diff --git a/README.md b/README.md index 3698161..00fad22 100644 --- a/README.md +++ b/README.md @@ -14,6 +14,7 @@ Learn more at [agentclientprotocol.com](https://agentclientprotocol.com/). - **Rust**: [`agent-client-protocol`](https://crates.io/crates/agent-client-protocol) - See [example_agent.rs](./rust/example_agent.rs) and [example_client.rs](./rust/example_client.rs) - **TypeScript**: [`@zed-industries/agent-client-protocol`](https://www.npmjs.com/package/@zed-industries/agent-client-protocol) - See [examples/](./typescript/examples/) +- **Go**: [`github.com/zed-industries/agent-client-protocol/go`](https://pkg.go.dev/github.com/zed-industries/agent-client-protocol/go) - See [example/](./go/example/) and the [Go README](./go/README.md) - **JSON Schema**: [schema.json](./schema/schema.json) ## Contributing diff --git a/docs/docs.json b/docs/docs.json index 52e202b..d8503e1 100644 --- a/docs/docs.json +++ b/docs/docs.json @@ -62,7 +62,7 @@ }, { "group": "Libraries", - "pages": ["libraries/typescript", "libraries/rust"] + "pages": ["libraries/typescript", "libraries/rust", "libraries/go"] }, { "group": "Community", diff --git a/docs/libraries/go.mdx b/docs/libraries/go.mdx new file mode 100644 index 0000000..27e8bf8 --- /dev/null +++ b/docs/libraries/go.mdx @@ -0,0 +1,33 @@ +--- +title: "Go" +description: "Go library for the Agent Client Protocol" +--- + +The [`github.com/zed-industries/agent-client-protocol/go`](https://pkg.go.dev/github.com/zed-industries/agent-client-protocol/go) +package provides implementations of both sides of the Agent Client Protocol that +you can use to build your own agent server or client. + +To get started, add the module to your project: + +```bash +go get github.com/zed-industries/agent-client-protocol/go@latest +``` + +Depending on what kind of tool you're building, you'll create either the +[AgentSideConnection](https://pkg.go.dev/github.com/zed-industries/agent-client-protocol/go#NewAgentSideConnection) +or the +[ClientSideConnection](https://pkg.go.dev/github.com/zed-industries/agent-client-protocol/go#NewClientSideConnection) +and implement the corresponding interfaces (`Agent`, `Client`). + +You can find example implementations of both sides in the +[main repository](https://github.com/zed-industries/agent-client-protocol/tree/main/go/example). +These can be run from your terminal or connected to external ACP agents, making +them great starting points for your own integration! + +Browse the Go package docs on +[pkg.go.dev](https://pkg.go.dev/github.com/zed-industries/agent-client-protocol/go) +for detailed API documentation. + +For a complete, production-ready implementation of an ACP agent, see the +[Gemini CLI](https://github.com/google-gemini/gemini-cli) which exposes an ACP +interface. The Go example client demonstrates connecting to it via stdio. diff --git a/go/example/agent/main.go b/go/example/agent/main.go index 9650f5b..a046732 100644 --- a/go/example/agent/main.go +++ b/go/example/agent/main.go @@ -87,6 +87,20 @@ func (a *exampleAgent) Prompt(params acp.PromptRequest) (acp.PromptResponse, err } func (a *exampleAgent) simulateTurn(ctx context.Context, sid string) error { + // disclaimer: stream a demo notice so clients see it's the example agent + if err := a.conn.SessionUpdate(acp.SessionNotification{ + SessionId: acp.SessionId(sid), + Update: acp.SessionUpdate{ + AgentMessageChunk: &acp.SessionUpdateAgentMessageChunk{ + Content: acp.TextBlock("ACP Go Example Agent — demo only (no AI model)."), + }, + }, + }); err != nil { + return err + } + if err := pause(ctx, 250*time.Millisecond); err != nil { + return err + } // initial message chunk if err := a.conn.SessionUpdate(acp.SessionNotification{ SessionId: acp.SessionId(sid), From e8062b1cfc2e3226b540144ae09a9ec3e87adebb Mon Sep 17 00:00:00 2001 From: Thomas Kosiewski Date: Sun, 31 Aug 2025 23:50:10 +0200 Subject: [PATCH 10/22] feat: add context parameter to all ACP interface methods and handle cancellation Change-Id: Ic242f8ab12e3760e7ef67b70385f8db1cf5a0262 Signed-off-by: Thomas Kosiewski --- go/acp_test.go | 307 ++++++++++++------ go/agent.go | 6 + go/agent_gen.go | 69 ++-- go/client_gen.go | 52 +-- go/cmd/generate/internal/emit/dispatch.go | 85 +++-- .../internal/emit/dispatch_helpers.go | 6 +- go/cmd/generate/internal/emit/types.go | 12 +- go/connection.go | 33 +- go/example/agent/main.go | 32 +- go/example/claude-code/main.go | 31 +- go/example/client/main.go | 30 +- go/example/gemini/main.go | 35 +- go/example_agent_test.go | 21 +- go/example_client_test.go | 16 +- go/example_gemini_test.go | 16 +- go/types_gen.go | 29 +- 16 files changed, 513 insertions(+), 267 deletions(-) diff --git a/go/acp_test.go b/go/acp_test.go index d2558bd..987d56c 100644 --- a/go/acp_test.go +++ b/go/acp_test.go @@ -1,6 +1,7 @@ package acp import ( + "context" "io" "slices" "sync" @@ -9,49 +10,49 @@ import ( ) type clientFuncs struct { - WriteTextFileFunc func(WriteTextFileRequest) error - ReadTextFileFunc func(ReadTextFileRequest) (ReadTextFileResponse, error) - RequestPermissionFunc func(RequestPermissionRequest) (RequestPermissionResponse, error) - SessionUpdateFunc func(SessionNotification) error + WriteTextFileFunc func(context.Context, WriteTextFileRequest) error + ReadTextFileFunc func(context.Context, ReadTextFileRequest) (ReadTextFileResponse, error) + RequestPermissionFunc func(context.Context, RequestPermissionRequest) (RequestPermissionResponse, error) + SessionUpdateFunc func(context.Context, SessionNotification) error } var _ Client = (*clientFuncs)(nil) -func (c clientFuncs) WriteTextFile(p WriteTextFileRequest) error { +func (c clientFuncs) WriteTextFile(ctx context.Context, p WriteTextFileRequest) error { if c.WriteTextFileFunc != nil { - return c.WriteTextFileFunc(p) + return c.WriteTextFileFunc(ctx, p) } return nil } -func (c clientFuncs) ReadTextFile(p ReadTextFileRequest) (ReadTextFileResponse, error) { +func (c clientFuncs) ReadTextFile(ctx context.Context, p ReadTextFileRequest) (ReadTextFileResponse, error) { if c.ReadTextFileFunc != nil { - return c.ReadTextFileFunc(p) + return c.ReadTextFileFunc(ctx, p) } return ReadTextFileResponse{}, nil } -func (c clientFuncs) RequestPermission(p RequestPermissionRequest) (RequestPermissionResponse, error) { +func (c clientFuncs) RequestPermission(ctx context.Context, p RequestPermissionRequest) (RequestPermissionResponse, error) { if c.RequestPermissionFunc != nil { - return c.RequestPermissionFunc(p) + return c.RequestPermissionFunc(ctx, p) } return RequestPermissionResponse{}, nil } -func (c clientFuncs) SessionUpdate(n SessionNotification) error { +func (c clientFuncs) SessionUpdate(ctx context.Context, n SessionNotification) error { if c.SessionUpdateFunc != nil { - return c.SessionUpdateFunc(n) + return c.SessionUpdateFunc(ctx, n) } return nil } type agentFuncs struct { - InitializeFunc func(InitializeRequest) (InitializeResponse, error) - NewSessionFunc func(NewSessionRequest) (NewSessionResponse, error) - LoadSessionFunc func(LoadSessionRequest) error - AuthenticateFunc func(AuthenticateRequest) error - PromptFunc func(PromptRequest) (PromptResponse, error) - CancelFunc func(CancelNotification) error + InitializeFunc func(context.Context, InitializeRequest) (InitializeResponse, error) + NewSessionFunc func(context.Context, NewSessionRequest) (NewSessionResponse, error) + LoadSessionFunc func(context.Context, LoadSessionRequest) error + AuthenticateFunc func(context.Context, AuthenticateRequest) error + PromptFunc func(context.Context, PromptRequest) (PromptResponse, error) + CancelFunc func(context.Context, CancelNotification) error } var ( @@ -59,85 +60,92 @@ var ( _ AgentLoader = (*agentFuncs)(nil) ) -func (a agentFuncs) Initialize(p InitializeRequest) (InitializeResponse, error) { +func (a agentFuncs) Initialize(ctx context.Context, p InitializeRequest) (InitializeResponse, error) { if a.InitializeFunc != nil { - return a.InitializeFunc(p) + return a.InitializeFunc(ctx, p) } return InitializeResponse{}, nil } -func (a agentFuncs) NewSession(p NewSessionRequest) (NewSessionResponse, error) { +func (a agentFuncs) NewSession(ctx context.Context, p NewSessionRequest) (NewSessionResponse, error) { if a.NewSessionFunc != nil { - return a.NewSessionFunc(p) + return a.NewSessionFunc(ctx, p) } return NewSessionResponse{}, nil } -func (a agentFuncs) LoadSession(p LoadSessionRequest) error { +func (a agentFuncs) LoadSession(ctx context.Context, p LoadSessionRequest) error { if a.LoadSessionFunc != nil { - return a.LoadSessionFunc(p) + return a.LoadSessionFunc(ctx, p) } return nil } -func (a agentFuncs) Authenticate(p AuthenticateRequest) error { +func (a agentFuncs) Authenticate(ctx context.Context, p AuthenticateRequest) error { if a.AuthenticateFunc != nil { - return a.AuthenticateFunc(p) + return a.AuthenticateFunc(ctx, p) } return nil } -func (a agentFuncs) Prompt(p PromptRequest) (PromptResponse, error) { +func (a agentFuncs) Prompt(ctx context.Context, p PromptRequest) (PromptResponse, error) { if a.PromptFunc != nil { - return a.PromptFunc(p) + return a.PromptFunc(ctx, p) } return PromptResponse{}, nil } -func (a agentFuncs) Cancel(n CancelNotification) error { +func (a agentFuncs) Cancel(ctx context.Context, n CancelNotification) error { if a.CancelFunc != nil { - return a.CancelFunc(n) + return a.CancelFunc(ctx, n) } return nil } // Test bidirectional error handling similar to typescript/acp.test.ts func TestConnectionHandlesErrorsBidirectional(t *testing.T) { + ctx := context.Background() c2aR, c2aW := io.Pipe() a2cR, a2cW := io.Pipe() c := NewClientSideConnection(clientFuncs{ - WriteTextFileFunc: func(WriteTextFileRequest) error { return &RequestError{Code: -32603, Message: "Write failed"} }, - ReadTextFileFunc: func(ReadTextFileRequest) (ReadTextFileResponse, error) { + WriteTextFileFunc: func(context.Context, WriteTextFileRequest) error { + return &RequestError{Code: -32603, Message: "Write failed"} + }, + ReadTextFileFunc: func(context.Context, ReadTextFileRequest) (ReadTextFileResponse, error) { return ReadTextFileResponse{}, &RequestError{Code: -32603, Message: "Read failed"} }, - RequestPermissionFunc: func(RequestPermissionRequest) (RequestPermissionResponse, error) { + RequestPermissionFunc: func(context.Context, RequestPermissionRequest) (RequestPermissionResponse, error) { return RequestPermissionResponse{}, &RequestError{Code: -32603, Message: "Permission denied"} }, - SessionUpdateFunc: func(SessionNotification) error { return nil }, + SessionUpdateFunc: func(context.Context, SessionNotification) error { return nil }, }, c2aW, a2cR) agentConn := NewAgentSideConnection(agentFuncs{ - InitializeFunc: func(InitializeRequest) (InitializeResponse, error) { + InitializeFunc: func(context.Context, InitializeRequest) (InitializeResponse, error) { return InitializeResponse{}, &RequestError{Code: -32603, Message: "Failed to initialize"} }, - NewSessionFunc: func(NewSessionRequest) (NewSessionResponse, error) { + NewSessionFunc: func(context.Context, NewSessionRequest) (NewSessionResponse, error) { return NewSessionResponse{}, &RequestError{Code: -32603, Message: "Failed to create session"} }, - LoadSessionFunc: func(LoadSessionRequest) error { return &RequestError{Code: -32603, Message: "Failed to load session"} }, - AuthenticateFunc: func(AuthenticateRequest) error { return &RequestError{Code: -32603, Message: "Authentication failed"} }, - PromptFunc: func(PromptRequest) (PromptResponse, error) { + LoadSessionFunc: func(context.Context, LoadSessionRequest) error { + return &RequestError{Code: -32603, Message: "Failed to load session"} + }, + AuthenticateFunc: func(context.Context, AuthenticateRequest) error { + return &RequestError{Code: -32603, Message: "Authentication failed"} + }, + PromptFunc: func(context.Context, PromptRequest) (PromptResponse, error) { return PromptResponse{}, &RequestError{Code: -32603, Message: "Prompt failed"} }, - CancelFunc: func(CancelNotification) error { return nil }, + CancelFunc: func(context.Context, CancelNotification) error { return nil }, }, a2cW, c2aR) // Client->Agent direction: expect error - if err := agentConn.WriteTextFile(WriteTextFileRequest{Path: "/test.txt", Content: "test", SessionId: "test-session"}); err == nil { + if err := agentConn.WriteTextFile(ctx, WriteTextFileRequest{Path: "/test.txt", Content: "test", SessionId: "test-session"}); err == nil { t.Fatalf("expected error for writeTextFile, got nil") } // Agent->Client direction: expect error - if _, err := c.NewSession(NewSessionRequest{Cwd: "/test", McpServers: []McpServer{}}); err == nil { + if _, err := c.NewSession(ctx, NewSessionRequest{Cwd: "/test", McpServers: []McpServer{}}); err == nil { t.Fatalf("expected error for newSession, got nil") } } @@ -151,32 +159,34 @@ func TestConnectionHandlesConcurrentRequests(t *testing.T) { requestCount := 0 _ = NewClientSideConnection(clientFuncs{ - WriteTextFileFunc: func(WriteTextFileRequest) error { + WriteTextFileFunc: func(context.Context, WriteTextFileRequest) error { mu.Lock() requestCount++ mu.Unlock() time.Sleep(40 * time.Millisecond) return nil }, - ReadTextFileFunc: func(p ReadTextFileRequest) (ReadTextFileResponse, error) { - return ReadTextFileResponse{Content: "Content of " + p.Path}, nil + ReadTextFileFunc: func(_ context.Context, req ReadTextFileRequest) (ReadTextFileResponse, error) { + return ReadTextFileResponse{Content: "Content of " + req.Path}, nil }, - RequestPermissionFunc: func(RequestPermissionRequest) (RequestPermissionResponse, error) { + RequestPermissionFunc: func(context.Context, RequestPermissionRequest) (RequestPermissionResponse, error) { return RequestPermissionResponse{Outcome: RequestPermissionOutcome{Selected: &RequestPermissionOutcomeSelected{OptionId: "allow"}}}, nil }, - SessionUpdateFunc: func(SessionNotification) error { return nil }, + SessionUpdateFunc: func(context.Context, SessionNotification) error { return nil }, }, c2aW, a2cR) agentConn := NewAgentSideConnection(agentFuncs{ - InitializeFunc: func(InitializeRequest) (InitializeResponse, error) { + InitializeFunc: func(context.Context, InitializeRequest) (InitializeResponse, error) { return InitializeResponse{ProtocolVersion: ProtocolVersionNumber, AgentCapabilities: AgentCapabilities{LoadSession: false}, AuthMethods: []AuthMethod{}}, nil }, - NewSessionFunc: func(NewSessionRequest) (NewSessionResponse, error) { + NewSessionFunc: func(context.Context, NewSessionRequest) (NewSessionResponse, error) { return NewSessionResponse{SessionId: "test-session"}, nil }, - LoadSessionFunc: func(LoadSessionRequest) error { return nil }, - AuthenticateFunc: func(AuthenticateRequest) error { return nil }, - PromptFunc: func(PromptRequest) (PromptResponse, error) { return PromptResponse{StopReason: "end_turn"}, nil }, - CancelFunc: func(CancelNotification) error { return nil }, + LoadSessionFunc: func(context.Context, LoadSessionRequest) error { return nil }, + AuthenticateFunc: func(context.Context, AuthenticateRequest) error { return nil }, + PromptFunc: func(context.Context, PromptRequest) (PromptResponse, error) { + return PromptResponse{StopReason: "end_turn"}, nil + }, + CancelFunc: func(context.Context, CancelNotification) error { return nil }, }, a2cW, c2aR) var wg sync.WaitGroup @@ -191,7 +201,7 @@ func TestConnectionHandlesConcurrentRequests(t *testing.T) { req := p go func() { defer wg.Done() - errs[idx] = agentConn.WriteTextFile(req) + errs[idx] = agentConn.WriteTextFile(context.Background(), req) }() } wg.Wait() @@ -218,44 +228,56 @@ func TestConnectionHandlesMessageOrdering(t *testing.T) { push := func(s string) { mu.Lock(); defer mu.Unlock(); log = append(log, s) } cs := NewClientSideConnection(clientFuncs{ - WriteTextFileFunc: func(p WriteTextFileRequest) error { push("writeTextFile called: " + p.Path); return nil }, - ReadTextFileFunc: func(p ReadTextFileRequest) (ReadTextFileResponse, error) { - push("readTextFile called: " + p.Path) + WriteTextFileFunc: func(_ context.Context, req WriteTextFileRequest) error { + push("writeTextFile called: " + req.Path) + return nil + }, + ReadTextFileFunc: func(_ context.Context, req ReadTextFileRequest) (ReadTextFileResponse, error) { + push("readTextFile called: " + req.Path) return ReadTextFileResponse{Content: "test content"}, nil }, - RequestPermissionFunc: func(p RequestPermissionRequest) (RequestPermissionResponse, error) { - push("requestPermission called: " + p.ToolCall.Title) + RequestPermissionFunc: func(_ context.Context, req RequestPermissionRequest) (RequestPermissionResponse, error) { + push("requestPermission called: " + req.ToolCall.Title) return RequestPermissionResponse{Outcome: RequestPermissionOutcome{Selected: &RequestPermissionOutcomeSelected{OptionId: "allow"}}}, nil }, - SessionUpdateFunc: func(SessionNotification) error { return nil }, + SessionUpdateFunc: func(context.Context, SessionNotification) error { return nil }, }, c2aW, a2cR) as := NewAgentSideConnection(agentFuncs{ - InitializeFunc: func(InitializeRequest) (InitializeResponse, error) { + InitializeFunc: func(context.Context, InitializeRequest) (InitializeResponse, error) { return InitializeResponse{ProtocolVersion: ProtocolVersionNumber, AgentCapabilities: AgentCapabilities{LoadSession: false}, AuthMethods: []AuthMethod{}}, nil }, - NewSessionFunc: func(p NewSessionRequest) (NewSessionResponse, error) { + NewSessionFunc: func(_ context.Context, p NewSessionRequest) (NewSessionResponse, error) { push("newSession called: " + p.Cwd) return NewSessionResponse{SessionId: "test-session"}, nil }, - LoadSessionFunc: func(p LoadSessionRequest) error { push("loadSession called: " + string(p.SessionId)); return nil }, - AuthenticateFunc: func(p AuthenticateRequest) error { push("authenticate called: " + string(p.MethodId)); return nil }, - PromptFunc: func(p PromptRequest) (PromptResponse, error) { + LoadSessionFunc: func(_ context.Context, p LoadSessionRequest) error { + push("loadSession called: " + string(p.SessionId)) + return nil + }, + AuthenticateFunc: func(_ context.Context, p AuthenticateRequest) error { + push("authenticate called: " + string(p.MethodId)) + return nil + }, + PromptFunc: func(_ context.Context, p PromptRequest) (PromptResponse, error) { push("prompt called: " + string(p.SessionId)) return PromptResponse{StopReason: "end_turn"}, nil }, - CancelFunc: func(p CancelNotification) error { push("cancelled called: " + string(p.SessionId)); return nil }, + CancelFunc: func(_ context.Context, p CancelNotification) error { + push("cancelled called: " + string(p.SessionId)) + return nil + }, }, a2cW, c2aR) - if _, err := cs.NewSession(NewSessionRequest{Cwd: "/test", McpServers: []McpServer{}}); err != nil { + if _, err := cs.NewSession(context.Background(), NewSessionRequest{Cwd: "/test", McpServers: []McpServer{}}); err != nil { t.Fatalf("newSession error: %v", err) } - if err := as.WriteTextFile(WriteTextFileRequest{Path: "/test.txt", Content: "test", SessionId: "test-session"}); err != nil { + if err := as.WriteTextFile(context.Background(), WriteTextFileRequest{Path: "/test.txt", Content: "test", SessionId: "test-session"}); err != nil { t.Fatalf("writeTextFile error: %v", err) } - if _, err := as.ReadTextFile(ReadTextFileRequest{Path: "/test.txt", SessionId: "test-session"}); err != nil { + if _, err := as.ReadTextFile(context.Background(), ReadTextFileRequest{Path: "/test.txt", SessionId: "test-session"}); err != nil { t.Fatalf("readTextFile error: %v", err) } - if _, err := as.RequestPermission(RequestPermissionRequest{ + if _, err := as.RequestPermission(context.Background(), RequestPermissionRequest{ SessionId: "test-session", ToolCall: ToolCallUpdate{ Title: "Execute command", @@ -305,14 +327,14 @@ func TestConnectionHandlesNotifications(t *testing.T) { push := func(s string) { mu.Lock(); logs = append(logs, s); mu.Unlock() } clientSide := NewClientSideConnection(clientFuncs{ - WriteTextFileFunc: func(WriteTextFileRequest) error { return nil }, - ReadTextFileFunc: func(ReadTextFileRequest) (ReadTextFileResponse, error) { + WriteTextFileFunc: func(context.Context, WriteTextFileRequest) error { return nil }, + ReadTextFileFunc: func(context.Context, ReadTextFileRequest) (ReadTextFileResponse, error) { return ReadTextFileResponse{Content: "test"}, nil }, - RequestPermissionFunc: func(RequestPermissionRequest) (RequestPermissionResponse, error) { + RequestPermissionFunc: func(context.Context, RequestPermissionRequest) (RequestPermissionResponse, error) { return RequestPermissionResponse{Outcome: RequestPermissionOutcome{Selected: &RequestPermissionOutcomeSelected{OptionId: "allow"}}}, nil }, - SessionUpdateFunc: func(n SessionNotification) error { + SessionUpdateFunc: func(_ context.Context, n SessionNotification) error { if n.Update.AgentMessageChunk != nil { if n.Update.AgentMessageChunk.Content.Text != nil { push("agent message: " + n.Update.AgentMessageChunk.Content.Text.Text) @@ -325,26 +347,31 @@ func TestConnectionHandlesNotifications(t *testing.T) { }, }, c2aW, a2cR) agentSide := NewAgentSideConnection(agentFuncs{ - InitializeFunc: func(InitializeRequest) (InitializeResponse, error) { + InitializeFunc: func(context.Context, InitializeRequest) (InitializeResponse, error) { return InitializeResponse{ProtocolVersion: ProtocolVersionNumber, AgentCapabilities: AgentCapabilities{LoadSession: false}, AuthMethods: []AuthMethod{}}, nil }, - NewSessionFunc: func(NewSessionRequest) (NewSessionResponse, error) { + NewSessionFunc: func(context.Context, NewSessionRequest) (NewSessionResponse, error) { return NewSessionResponse{SessionId: "test-session"}, nil }, - LoadSessionFunc: func(LoadSessionRequest) error { return nil }, - AuthenticateFunc: func(AuthenticateRequest) error { return nil }, - PromptFunc: func(PromptRequest) (PromptResponse, error) { return PromptResponse{StopReason: "end_turn"}, nil }, - CancelFunc: func(p CancelNotification) error { push("cancelled: " + string(p.SessionId)); return nil }, + LoadSessionFunc: func(context.Context, LoadSessionRequest) error { return nil }, + AuthenticateFunc: func(context.Context, AuthenticateRequest) error { return nil }, + PromptFunc: func(context.Context, PromptRequest) (PromptResponse, error) { + return PromptResponse{StopReason: "end_turn"}, nil + }, + CancelFunc: func(_ context.Context, p CancelNotification) error { + push("cancelled: " + string(p.SessionId)) + return nil + }, }, a2cW, c2aR) - if err := agentSide.SessionUpdate(SessionNotification{ + if err := agentSide.SessionUpdate(context.Background(), SessionNotification{ SessionId: "test-session", Update: SessionUpdate{AgentMessageChunk: &SessionUpdateAgentMessageChunk{Content: ContentBlock{Type: "text", Text: &TextContent{Text: "Hello from agent"}}}}, }); err != nil { t.Fatalf("sessionUpdate error: %v", err) } - if err := clientSide.Cancel(CancelNotification{SessionId: "test-session"}); err != nil { + if err := clientSide.Cancel(context.Background(), CancelNotification{SessionId: "test-session"}); err != nil { t.Fatalf("cancel error: %v", err) } @@ -365,29 +392,31 @@ func TestConnectionHandlesInitialize(t *testing.T) { a2cR, a2cW := io.Pipe() agentConn := NewClientSideConnection(clientFuncs{ - WriteTextFileFunc: func(WriteTextFileRequest) error { return nil }, - ReadTextFileFunc: func(ReadTextFileRequest) (ReadTextFileResponse, error) { + WriteTextFileFunc: func(context.Context, WriteTextFileRequest) error { return nil }, + ReadTextFileFunc: func(context.Context, ReadTextFileRequest) (ReadTextFileResponse, error) { return ReadTextFileResponse{Content: "test"}, nil }, - RequestPermissionFunc: func(RequestPermissionRequest) (RequestPermissionResponse, error) { + RequestPermissionFunc: func(context.Context, RequestPermissionRequest) (RequestPermissionResponse, error) { return RequestPermissionResponse{Outcome: RequestPermissionOutcome{Selected: &RequestPermissionOutcomeSelected{OptionId: "allow"}}}, nil }, - SessionUpdateFunc: func(SessionNotification) error { return nil }, + SessionUpdateFunc: func(context.Context, SessionNotification) error { return nil }, }, c2aW, a2cR) _ = NewAgentSideConnection(agentFuncs{ - InitializeFunc: func(p InitializeRequest) (InitializeResponse, error) { + InitializeFunc: func(_ context.Context, p InitializeRequest) (InitializeResponse, error) { return InitializeResponse{ProtocolVersion: p.ProtocolVersion, AgentCapabilities: AgentCapabilities{LoadSession: true}, AuthMethods: []AuthMethod{{Id: "oauth", Name: "OAuth", Description: "Authenticate with OAuth"}}}, nil }, - NewSessionFunc: func(NewSessionRequest) (NewSessionResponse, error) { + NewSessionFunc: func(context.Context, NewSessionRequest) (NewSessionResponse, error) { return NewSessionResponse{SessionId: "test-session"}, nil }, - LoadSessionFunc: func(LoadSessionRequest) error { return nil }, - AuthenticateFunc: func(AuthenticateRequest) error { return nil }, - PromptFunc: func(PromptRequest) (PromptResponse, error) { return PromptResponse{StopReason: "end_turn"}, nil }, - CancelFunc: func(CancelNotification) error { return nil }, + LoadSessionFunc: func(context.Context, LoadSessionRequest) error { return nil }, + AuthenticateFunc: func(context.Context, AuthenticateRequest) error { return nil }, + PromptFunc: func(context.Context, PromptRequest) (PromptResponse, error) { + return PromptResponse{StopReason: "end_turn"}, nil + }, + CancelFunc: func(context.Context, CancelNotification) error { return nil }, }, a2cW, c2aR) - resp, err := agentConn.Initialize(InitializeRequest{ + resp, err := agentConn.Initialize(context.Background(), InitializeRequest{ ProtocolVersion: ProtocolVersionNumber, ClientCapabilities: ClientCapabilities{Fs: FileSystemCapability{ReadTextFile: false, WriteTextFile: false}}, }) @@ -408,3 +437,95 @@ func TestConnectionHandlesInitialize(t *testing.T) { func ptr[T any](t T) *T { return &t } + +// Test that canceling the client's Prompt context sends a session/cancel +// to the agent, and that the connection remains usable afterwards. +func TestPromptCancellationSendsCancelAndAllowsNewSession(t *testing.T) { + c2aR, c2aW := io.Pipe() + a2cR, a2cW := io.Pipe() + + cancelCh := make(chan string, 1) + promptDone := make(chan struct{}, 1) + + // Agent side: Prompt waits for ctx cancellation; Cancel records the sessionId + _ = NewAgentSideConnection(agentFuncs{ + InitializeFunc: func(context.Context, InitializeRequest) (InitializeResponse, error) { + return InitializeResponse{ProtocolVersion: ProtocolVersionNumber}, nil + }, + NewSessionFunc: func(context.Context, NewSessionRequest) (NewSessionResponse, error) { + return NewSessionResponse{SessionId: "s-1"}, nil + }, + LoadSessionFunc: func(context.Context, LoadSessionRequest) error { return nil }, + AuthenticateFunc: func(context.Context, AuthenticateRequest) error { return nil }, + PromptFunc: func(ctx context.Context, p PromptRequest) (PromptResponse, error) { + <-ctx.Done() + // mark that prompt finished due to cancellation + select { + case promptDone <- struct{}{}: + default: + } + return PromptResponse{StopReason: StopReasonCancelled}, nil + }, + CancelFunc: func(context.Context, CancelNotification) error { + select { + case cancelCh <- "s-1": + default: + } + return nil + }, + }, a2cW, c2aR) + + // Client side + cs := NewClientSideConnection(clientFuncs{ + WriteTextFileFunc: func(context.Context, WriteTextFileRequest) error { return nil }, + ReadTextFileFunc: func(context.Context, ReadTextFileRequest) (ReadTextFileResponse, error) { + return ReadTextFileResponse{Content: ""}, nil + }, + RequestPermissionFunc: func(context.Context, RequestPermissionRequest) (RequestPermissionResponse, error) { + return RequestPermissionResponse{}, nil + }, + SessionUpdateFunc: func(context.Context, SessionNotification) error { return nil }, + }, c2aW, a2cR) + + // Initialize and create a session + if _, err := cs.Initialize(context.Background(), InitializeRequest{ProtocolVersion: ProtocolVersionNumber}); err != nil { + t.Fatalf("initialize: %v", err) + } + sess, err := cs.NewSession(context.Background(), NewSessionRequest{Cwd: "/", McpServers: []McpServer{}}) + if err != nil { + t.Fatalf("newSession: %v", err) + } + + // Start a prompt with a cancelable context, then cancel it + turnCtx, cancel := context.WithCancel(context.Background()) + errCh := make(chan error, 1) + go func() { + _, err := cs.Prompt(turnCtx, PromptRequest{SessionId: sess.SessionId, Prompt: []ContentBlock{TextBlock("hello")}}) + errCh <- err + }() + + time.Sleep(50 * time.Millisecond) + cancel() + + // Expect a session/cancel notification on the agent side + select { + case sid := <-cancelCh: + if sid != string(sess.SessionId) && sid != "s-1" { // allow either depending on agent NewSession response + t.Fatalf("unexpected cancel session id: %q", sid) + } + case <-time.After(1 * time.Second): + t.Fatalf("timeout waiting for session/cancel") + } + + // Agent's prompt should have finished due to ctx cancellation + select { + case <-promptDone: + case <-time.After(1 * time.Second): + t.Fatalf("timeout waiting for prompt to finish after cancel") + } + + // Connection remains usable: create another session + if _, err := cs.NewSession(context.Background(), NewSessionRequest{Cwd: "/", McpServers: []McpServer{}}); err != nil { + t.Fatalf("newSession after cancel: %v", err) + } +} diff --git a/go/agent.go b/go/agent.go index 18ad69f..b32ad16 100644 --- a/go/agent.go +++ b/go/agent.go @@ -1,13 +1,18 @@ package acp import ( + "context" "io" + "sync" ) // AgentSideConnection represents the agent's view of a connection to a client. type AgentSideConnection struct { conn *Connection agent Agent + + mu sync.Mutex + sessionCancels map[string]context.CancelFunc } // NewAgentSideConnection creates a new agent-side connection bound to the @@ -15,6 +20,7 @@ type AgentSideConnection struct { func NewAgentSideConnection(agent Agent, peerInput io.Writer, peerOutput io.Reader) *AgentSideConnection { asc := &AgentSideConnection{} asc.agent = agent + asc.sessionCancels = make(map[string]context.CancelFunc) asc.conn = NewConnection(asc.handle, peerInput, peerOutput) return asc } diff --git a/go/agent_gen.go b/go/agent_gen.go index 9897a38..2a80522 100644 --- a/go/agent_gen.go +++ b/go/agent_gen.go @@ -2,9 +2,12 @@ package acp -import "encoding/json" +import ( + "context" + "encoding/json" +) -func (a *AgentSideConnection) handle(method string, params json.RawMessage) (any, *RequestError) { +func (a *AgentSideConnection) handle(ctx context.Context, method string, params json.RawMessage) (any, *RequestError) { switch method { case AgentMethodAuthenticate: var p AuthenticateRequest @@ -14,7 +17,7 @@ func (a *AgentSideConnection) handle(method string, params json.RawMessage) (any if err := p.Validate(); err != nil { return nil, NewInvalidParams(map[string]any{"error": err.Error()}) } - if err := a.agent.Authenticate(p); err != nil { + if err := a.agent.Authenticate(ctx, p); err != nil { return nil, toReqErr(err) } return nil, nil @@ -26,7 +29,7 @@ func (a *AgentSideConnection) handle(method string, params json.RawMessage) (any if err := p.Validate(); err != nil { return nil, NewInvalidParams(map[string]any{"error": err.Error()}) } - resp, err := a.agent.Initialize(p) + resp, err := a.agent.Initialize(ctx, p) if err != nil { return nil, toReqErr(err) } @@ -39,7 +42,13 @@ func (a *AgentSideConnection) handle(method string, params json.RawMessage) (any if err := p.Validate(); err != nil { return nil, NewInvalidParams(map[string]any{"error": err.Error()}) } - if err := a.agent.Cancel(p); err != nil { + a.mu.Lock() + if cn, ok := a.sessionCancels[string(p.SessionId)]; ok { + cn() + delete(a.sessionCancels, string(p.SessionId)) + } + a.mu.Unlock() + if err := a.agent.Cancel(ctx, p); err != nil { return nil, toReqErr(err) } return nil, nil @@ -55,7 +64,7 @@ func (a *AgentSideConnection) handle(method string, params json.RawMessage) (any if !ok { return nil, NewMethodNotFound(method) } - if err := loader.LoadSession(p); err != nil { + if err := loader.LoadSession(ctx, p); err != nil { return nil, toReqErr(err) } return nil, nil @@ -67,7 +76,7 @@ func (a *AgentSideConnection) handle(method string, params json.RawMessage) (any if err := p.Validate(); err != nil { return nil, NewInvalidParams(map[string]any{"error": err.Error()}) } - resp, err := a.agent.NewSession(p) + resp, err := a.agent.NewSession(ctx, p) if err != nil { return nil, toReqErr(err) } @@ -80,7 +89,19 @@ func (a *AgentSideConnection) handle(method string, params json.RawMessage) (any if err := p.Validate(); err != nil { return nil, NewInvalidParams(map[string]any{"error": err.Error()}) } - resp, err := a.agent.Prompt(p) + var reqCtx context.Context + var cancel context.CancelFunc + reqCtx, cancel = context.WithCancel(ctx) + a.mu.Lock() + if prev, ok := a.sessionCancels[string(p.SessionId)]; ok { + prev() + } + a.sessionCancels[string(p.SessionId)] = cancel + a.mu.Unlock() + resp, err := a.agent.Prompt(reqCtx, p) + a.mu.Lock() + delete(a.sessionCancels, string(p.SessionId)) + a.mu.Unlock() if err != nil { return nil, toReqErr(err) } @@ -89,32 +110,32 @@ func (a *AgentSideConnection) handle(method string, params json.RawMessage) (any return nil, NewMethodNotFound(method) } } -func (c *AgentSideConnection) ReadTextFile(params ReadTextFileRequest) (ReadTextFileResponse, error) { - resp, err := SendRequest[ReadTextFileResponse](c.conn, ClientMethodFsReadTextFile, params) +func (c *AgentSideConnection) ReadTextFile(ctx context.Context, params ReadTextFileRequest) (ReadTextFileResponse, error) { + resp, err := SendRequest[ReadTextFileResponse](c.conn, ctx, ClientMethodFsReadTextFile, params) return resp, err } -func (c *AgentSideConnection) WriteTextFile(params WriteTextFileRequest) error { - return c.conn.SendRequestNoResult(ClientMethodFsWriteTextFile, params) +func (c *AgentSideConnection) WriteTextFile(ctx context.Context, params WriteTextFileRequest) error { + return c.conn.SendRequestNoResult(ctx, ClientMethodFsWriteTextFile, params) } -func (c *AgentSideConnection) RequestPermission(params RequestPermissionRequest) (RequestPermissionResponse, error) { - resp, err := SendRequest[RequestPermissionResponse](c.conn, ClientMethodSessionRequestPermission, params) +func (c *AgentSideConnection) RequestPermission(ctx context.Context, params RequestPermissionRequest) (RequestPermissionResponse, error) { + resp, err := SendRequest[RequestPermissionResponse](c.conn, ctx, ClientMethodSessionRequestPermission, params) return resp, err } -func (c *AgentSideConnection) SessionUpdate(params SessionNotification) error { - return c.conn.SendNotification(ClientMethodSessionUpdate, params) +func (c *AgentSideConnection) SessionUpdate(ctx context.Context, params SessionNotification) error { + return c.conn.SendNotification(ctx, ClientMethodSessionUpdate, params) } -func (c *AgentSideConnection) CreateTerminal(params CreateTerminalRequest) (CreateTerminalResponse, error) { - resp, err := SendRequest[CreateTerminalResponse](c.conn, ClientMethodTerminalCreate, params) +func (c *AgentSideConnection) CreateTerminal(ctx context.Context, params CreateTerminalRequest) (CreateTerminalResponse, error) { + resp, err := SendRequest[CreateTerminalResponse](c.conn, ctx, ClientMethodTerminalCreate, params) return resp, err } -func (c *AgentSideConnection) TerminalOutput(params TerminalOutputRequest) (TerminalOutputResponse, error) { - resp, err := SendRequest[TerminalOutputResponse](c.conn, ClientMethodTerminalOutput, params) +func (c *AgentSideConnection) TerminalOutput(ctx context.Context, params TerminalOutputRequest) (TerminalOutputResponse, error) { + resp, err := SendRequest[TerminalOutputResponse](c.conn, ctx, ClientMethodTerminalOutput, params) return resp, err } -func (c *AgentSideConnection) ReleaseTerminal(params ReleaseTerminalRequest) error { - return c.conn.SendRequestNoResult(ClientMethodTerminalRelease, params) +func (c *AgentSideConnection) ReleaseTerminal(ctx context.Context, params ReleaseTerminalRequest) error { + return c.conn.SendRequestNoResult(ctx, ClientMethodTerminalRelease, params) } -func (c *AgentSideConnection) WaitForTerminalExit(params WaitForTerminalExitRequest) (WaitForTerminalExitResponse, error) { - resp, err := SendRequest[WaitForTerminalExitResponse](c.conn, ClientMethodTerminalWaitForExit, params) +func (c *AgentSideConnection) WaitForTerminalExit(ctx context.Context, params WaitForTerminalExitRequest) (WaitForTerminalExitResponse, error) { + resp, err := SendRequest[WaitForTerminalExitResponse](c.conn, ctx, ClientMethodTerminalWaitForExit, params) return resp, err } diff --git a/go/client_gen.go b/go/client_gen.go index 9ea61aa..ead5e0a 100644 --- a/go/client_gen.go +++ b/go/client_gen.go @@ -2,9 +2,12 @@ package acp -import "encoding/json" +import ( + "context" + "encoding/json" +) -func (c *ClientSideConnection) handle(method string, params json.RawMessage) (any, *RequestError) { +func (c *ClientSideConnection) handle(ctx context.Context, method string, params json.RawMessage) (any, *RequestError) { switch method { case ClientMethodFsReadTextFile: var p ReadTextFileRequest @@ -14,7 +17,7 @@ func (c *ClientSideConnection) handle(method string, params json.RawMessage) (an if err := p.Validate(); err != nil { return nil, NewInvalidParams(map[string]any{"error": err.Error()}) } - resp, err := c.client.ReadTextFile(p) + resp, err := c.client.ReadTextFile(ctx, p) if err != nil { return nil, toReqErr(err) } @@ -27,7 +30,7 @@ func (c *ClientSideConnection) handle(method string, params json.RawMessage) (an if err := p.Validate(); err != nil { return nil, NewInvalidParams(map[string]any{"error": err.Error()}) } - if err := c.client.WriteTextFile(p); err != nil { + if err := c.client.WriteTextFile(ctx, p); err != nil { return nil, toReqErr(err) } return nil, nil @@ -39,7 +42,7 @@ func (c *ClientSideConnection) handle(method string, params json.RawMessage) (an if err := p.Validate(); err != nil { return nil, NewInvalidParams(map[string]any{"error": err.Error()}) } - resp, err := c.client.RequestPermission(p) + resp, err := c.client.RequestPermission(ctx, p) if err != nil { return nil, toReqErr(err) } @@ -52,7 +55,7 @@ func (c *ClientSideConnection) handle(method string, params json.RawMessage) (an if err := p.Validate(); err != nil { return nil, NewInvalidParams(map[string]any{"error": err.Error()}) } - if err := c.client.SessionUpdate(p); err != nil { + if err := c.client.SessionUpdate(ctx, p); err != nil { return nil, toReqErr(err) } return nil, nil @@ -68,7 +71,7 @@ func (c *ClientSideConnection) handle(method string, params json.RawMessage) (an if !ok { return nil, NewMethodNotFound(method) } - resp, err := t.CreateTerminal(p) + resp, err := t.CreateTerminal(ctx, p) if err != nil { return nil, toReqErr(err) } @@ -85,7 +88,7 @@ func (c *ClientSideConnection) handle(method string, params json.RawMessage) (an if !ok { return nil, NewMethodNotFound(method) } - resp, err := t.TerminalOutput(p) + resp, err := t.TerminalOutput(ctx, p) if err != nil { return nil, toReqErr(err) } @@ -102,7 +105,7 @@ func (c *ClientSideConnection) handle(method string, params json.RawMessage) (an if !ok { return nil, NewMethodNotFound(method) } - if err := t.ReleaseTerminal(p); err != nil { + if err := t.ReleaseTerminal(ctx, p); err != nil { return nil, toReqErr(err) } return nil, nil @@ -118,7 +121,7 @@ func (c *ClientSideConnection) handle(method string, params json.RawMessage) (an if !ok { return nil, NewMethodNotFound(method) } - resp, err := t.WaitForTerminalExit(p) + resp, err := t.WaitForTerminalExit(ctx, p) if err != nil { return nil, toReqErr(err) } @@ -127,24 +130,29 @@ func (c *ClientSideConnection) handle(method string, params json.RawMessage) (an return nil, NewMethodNotFound(method) } } -func (c *ClientSideConnection) Authenticate(params AuthenticateRequest) error { - return c.conn.SendRequestNoResult(AgentMethodAuthenticate, params) +func (c *ClientSideConnection) Authenticate(ctx context.Context, params AuthenticateRequest) error { + return c.conn.SendRequestNoResult(ctx, AgentMethodAuthenticate, params) } -func (c *ClientSideConnection) Initialize(params InitializeRequest) (InitializeResponse, error) { - resp, err := SendRequest[InitializeResponse](c.conn, AgentMethodInitialize, params) +func (c *ClientSideConnection) Initialize(ctx context.Context, params InitializeRequest) (InitializeResponse, error) { + resp, err := SendRequest[InitializeResponse](c.conn, ctx, AgentMethodInitialize, params) return resp, err } -func (c *ClientSideConnection) Cancel(params CancelNotification) error { - return c.conn.SendNotification(AgentMethodSessionCancel, params) +func (c *ClientSideConnection) Cancel(ctx context.Context, params CancelNotification) error { + return c.conn.SendNotification(ctx, AgentMethodSessionCancel, params) } -func (c *ClientSideConnection) LoadSession(params LoadSessionRequest) error { - return c.conn.SendRequestNoResult(AgentMethodSessionLoad, params) +func (c *ClientSideConnection) LoadSession(ctx context.Context, params LoadSessionRequest) error { + return c.conn.SendRequestNoResult(ctx, AgentMethodSessionLoad, params) } -func (c *ClientSideConnection) NewSession(params NewSessionRequest) (NewSessionResponse, error) { - resp, err := SendRequest[NewSessionResponse](c.conn, AgentMethodSessionNew, params) +func (c *ClientSideConnection) NewSession(ctx context.Context, params NewSessionRequest) (NewSessionResponse, error) { + resp, err := SendRequest[NewSessionResponse](c.conn, ctx, AgentMethodSessionNew, params) return resp, err } -func (c *ClientSideConnection) Prompt(params PromptRequest) (PromptResponse, error) { - resp, err := SendRequest[PromptResponse](c.conn, AgentMethodSessionPrompt, params) +func (c *ClientSideConnection) Prompt(ctx context.Context, params PromptRequest) (PromptResponse, error) { + resp, err := SendRequest[PromptResponse](c.conn, ctx, AgentMethodSessionPrompt, params) + if err != nil { + if ctx.Err() != nil { + _ = c.Cancel(context.Background(), CancelNotification{SessionId: params.SessionId}) + } + } return resp, err } diff --git a/go/cmd/generate/internal/emit/dispatch.go b/go/cmd/generate/internal/emit/dispatch.go index 3a144cd..7b331fe 100644 --- a/go/cmd/generate/internal/emit/dispatch.go +++ b/go/cmd/generate/internal/emit/dispatch.go @@ -34,6 +34,18 @@ func WriteDispatchJen(outDir string, schema *load.Schema, meta *load.Meta) error caseBody := []Code{} if mi.Notif != "" { caseBody = append(caseBody, jUnmarshalValidate(mi.Notif)...) + // Special-case: session/cancel should also cancel any in-flight prompt ctx for the session. + if mi.Method == "session/cancel" { + caseBody = append(caseBody, + // cancel active prompt context if present + Id("a").Dot("mu").Dot("Lock").Call(), + If(List(Id("cn"), Id("ok")).Op(":=").Id("a").Dot("sessionCancels").Index(Id("string").Call(Id("p").Dot("SessionId"))), Id("ok")).Block( + Id("cn").Call(), + Id("delete").Call(Id("a").Dot("sessionCancels"), Id("string").Call(Id("p").Dot("SessionId"))), + ), + Id("a").Dot("mu").Dot("Unlock").Call(), + ) + } callName := ir.DispatchMethodNameForNotification(k, mi.Notif) caseBody = append(caseBody, jCallNotification("a.agent", callName)...) } else if mi.Req != "" { @@ -44,7 +56,27 @@ func WriteDispatchJen(outDir string, schema *load.Schema, meta *load.Meta) error if pre != nil { caseBody = append(caseBody, pre...) } - if ir.IsNullResponse(schema.Defs[respName]) { + if mi.Method == "session/prompt" { + // Derive a cancellable context per session prompt. + caseBody = append(caseBody, + Var().Id("reqCtx").Qual("context", "Context"), Var().Id("cancel").Qual("context", "CancelFunc"), + List(Id("reqCtx"), Id("cancel")).Op("=").Qual("context", "WithCancel").Call(Id("ctx")), + Id("a").Dot("mu").Dot("Lock").Call(), + If(List(Id("prev"), Id("ok")).Op(":=").Id("a").Dot("sessionCancels").Index(Id("string").Call(Id("p").Dot("SessionId"))), Id("ok")).Block(Id("prev").Call()), + Id("a").Dot("sessionCancels").Index(Id("string").Call(Id("p").Dot("SessionId"))).Op("=").Id("cancel"), + Id("a").Dot("mu").Dot("Unlock").Call(), + ) + // Call agent.Prompt(reqCtx, p) + caseBody = append(caseBody, + List(Id("resp"), Id("err")).Op(":=").Id(recv).Dot(methodName).Call(Id("reqCtx"), Id("p")), + // cleanup entry after return + Id("a").Dot("mu").Dot("Lock").Call(), + Id("delete").Call(Id("a").Dot("sessionCancels"), Id("string").Call(Id("p").Dot("SessionId"))), + Id("a").Dot("mu").Dot("Unlock").Call(), + If(Id("err").Op("!=").Nil()).Block(jRetToReqErr()), + Return(Id("resp"), Nil()), + ) + } else if ir.IsNullResponse(schema.Defs[respName]) { caseBody = append(caseBody, jCallRequestNoResp(recv, methodName)...) } else { caseBody = append(caseBody, jCallRequestWithResp(recv, methodName, respName)...) @@ -56,7 +88,7 @@ func WriteDispatchJen(outDir string, schema *load.Schema, meta *load.Meta) error } switchCases = append(switchCases, Default().Block(Return(Nil(), Id("NewMethodNotFound").Call(Id("method"))))) fAgent.Func().Params(Id("a").Op("*").Id("AgentSideConnection")).Id("handle").Params( - Id("method").String(), Id("params").Qual("encoding/json", "RawMessage")). + Id("ctx").Qual("context", "Context"), Id("method").String(), Id("params").Qual("encoding/json", "RawMessage")). Params(Any(), Op("*").Id("RequestError")). Block(Switch(Id("method")).Block(switchCases...)) @@ -93,19 +125,19 @@ func WriteDispatchJen(outDir string, schema *load.Schema, meta *load.Meta) error case "session/cancel": name = "Cancel" } - fAgent.Func().Params(Id("c").Op("*").Id("AgentSideConnection")).Id(name).Params(Id("params").Id(mi.Notif)).Error(). - Block(Return(Id("c").Dot("conn").Dot("SendNotification").Call(Id(constName), Id("params")))) + fAgent.Func().Params(Id("c").Op("*").Id("AgentSideConnection")).Id(name).Params(Id("ctx").Qual("context", "Context"), Id("params").Id(mi.Notif)).Error(). + Block(Return(Id("c").Dot("conn").Dot("SendNotification").Call(Id("ctx"), Id(constName), Id("params")))) } else if mi.Req != "" { respName := strings.TrimSuffix(mi.Req, "Request") + "Response" if ir.IsNullResponse(schema.Defs[respName]) { fAgent.Func().Params(Id("c").Op("*").Id("AgentSideConnection")).Id(strings.TrimSuffix(mi.Req, "Request")). - Params(Id("params").Id(mi.Req)).Error(). - Block(Return(Id("c").Dot("conn").Dot("SendRequestNoResult").Call(Id(constName), Id("params")))) + Params(Id("ctx").Qual("context", "Context"), Id("params").Id(mi.Req)).Error(). + Block(Return(Id("c").Dot("conn").Dot("SendRequestNoResult").Call(Id("ctx"), Id(constName), Id("params")))) } else { fAgent.Func().Params(Id("c").Op("*").Id("AgentSideConnection")).Id(strings.TrimSuffix(mi.Req, "Request")). - Params(Id("params").Id(mi.Req)).Params(Id(respName), Error()). + Params(Id("ctx").Qual("context", "Context"), Id("params").Id(mi.Req)).Params(Id(respName), Error()). Block( - List(Id("resp"), Id("err")).Op(":=").Id("SendRequest").Types(Id(respName)).Call(Id("c").Dot("conn"), Id(constName), Id("params")), + List(Id("resp"), Id("err")).Op(":=").Id("SendRequest").Types(Id(respName)).Call(Id("c").Dot("conn"), Id("ctx"), Id(constName), Id("params")), Return(Id("resp"), Id("err")), ) } @@ -163,7 +195,7 @@ func WriteDispatchJen(outDir string, schema *load.Schema, meta *load.Meta) error } cCases = append(cCases, Default().Block(Return(Nil(), Id("NewMethodNotFound").Call(Id("method"))))) fClient.Func().Params(Id("c").Op("*").Id("ClientSideConnection")).Id("handle").Params( - Id("method").String(), Id("params").Qual("encoding/json", "RawMessage")). + Id("ctx").Qual("context", "Context"), Id("method").String(), Id("params").Qual("encoding/json", "RawMessage")). Params(Any(), Op("*").Id("RequestError")). Block(Switch(Id("method")).Block(cCases...)) @@ -191,21 +223,36 @@ func WriteDispatchJen(outDir string, schema *load.Schema, meta *load.Meta) error case "session/cancel": name = "Cancel" } - fClient.Func().Params(Id("c").Op("*").Id("ClientSideConnection")).Id(name).Params(Id("params").Id(mi.Notif)).Error(). - Block(Return(Id("c").Dot("conn").Dot("SendNotification").Call(Id(constName), Id("params")))) + fClient.Func().Params(Id("c").Op("*").Id("ClientSideConnection")).Id(name).Params(Id("ctx").Qual("context", "Context"), Id("params").Id(mi.Notif)).Error(). + Block(Return(Id("c").Dot("conn").Dot("SendNotification").Call(Id("ctx"), Id(constName), Id("params")))) } else if mi.Req != "" { respName := strings.TrimSuffix(mi.Req, "Request") + "Response" if ir.IsNullResponse(schema.Defs[respName]) { fClient.Func().Params(Id("c").Op("*").Id("ClientSideConnection")).Id(strings.TrimSuffix(mi.Req, "Request")). - Params(Id("params").Id(mi.Req)).Error(). - Block(Return(Id("c").Dot("conn").Dot("SendRequestNoResult").Call(Id(constName), Id("params")))) + Params(Id("ctx").Qual("context", "Context"), Id("params").Id(mi.Req)).Error(). + Block(Return(Id("c").Dot("conn").Dot("SendRequestNoResult").Call(Id("ctx"), Id(constName), Id("params")))) } else { - fClient.Func().Params(Id("c").Op("*").Id("ClientSideConnection")).Id(strings.TrimSuffix(mi.Req, "Request")). - Params(Id("params").Id(mi.Req)).Params(Id(respName), Error()). - Block( - List(Id("resp"), Id("err")).Op(":=").Id("SendRequest").Types(Id(respName)).Call(Id("c").Dot("conn"), Id(constName), Id("params")), - Return(Id("resp"), Id("err")), - ) + // Special-case: session/prompt — if ctx was canceled, send session/cancel best-effort. + if mi.Method == "session/prompt" { + fClient.Func().Params(Id("c").Op("*").Id("ClientSideConnection")).Id(strings.TrimSuffix(mi.Req, "Request")). + Params(Id("ctx").Qual("context", "Context"), Id("params").Id(mi.Req)).Params(Id(respName), Error()). + Block( + List(Id("resp"), Id("err")).Op(":=").Id("SendRequest").Types(Id(respName)).Call(Id("c").Dot("conn"), Id("ctx"), Id(constName), Id("params")), + If(Id("err").Op("!=").Nil()).Block( + If(Id("ctx").Dot("Err").Call().Op("!=").Nil()).Block( + Id("_ ").Op("=").Id("c").Dot("Cancel").Call(Qual("context", "Background").Call(), Id("CancelNotification").Values(Dict{Id("SessionId"): Id("params").Dot("SessionId")})), + ), + ), + Return(Id("resp"), Id("err")), + ) + } else { + fClient.Func().Params(Id("c").Op("*").Id("ClientSideConnection")).Id(strings.TrimSuffix(mi.Req, "Request")). + Params(Id("ctx").Qual("context", "Context"), Id("params").Id(mi.Req)).Params(Id(respName), Error()). + Block( + List(Id("resp"), Id("err")).Op(":=").Id("SendRequest").Types(Id(respName)).Call(Id("c").Dot("conn"), Id("ctx"), Id(constName), Id("params")), + Return(Id("resp"), Id("err")), + ) + } } } } diff --git a/go/cmd/generate/internal/emit/dispatch_helpers.go b/go/cmd/generate/internal/emit/dispatch_helpers.go index 98c7ed3..3167163 100644 --- a/go/cmd/generate/internal/emit/dispatch_helpers.go +++ b/go/cmd/generate/internal/emit/dispatch_helpers.go @@ -57,14 +57,14 @@ func jClientAssert(binding ir.MethodBinding) ([]Code, string) { // Request call emitters for handlers func jCallRequestNoResp(recv, methodName string) []Code { return []Code{ - If(List(Id("err")).Op(":=").Id(recv).Dot(methodName).Call(Id("p")), Id("err").Op("!=").Nil()).Block(jRetToReqErr()), + If(List(Id("err")).Op(":=").Id(recv).Dot(methodName).Call(Id("ctx"), Id("p")), Id("err").Op("!=").Nil()).Block(jRetToReqErr()), Return(Nil(), Nil()), } } func jCallRequestWithResp(recv, methodName, respType string) []Code { return []Code{ - List(Id("resp"), Id("err")).Op(":=").Id(recv).Dot(methodName).Call(Id("p")), + List(Id("resp"), Id("err")).Op(":=").Id(recv).Dot(methodName).Call(Id("ctx"), Id("p")), If(Id("err").Op("!=").Nil()).Block(jRetToReqErr()), Return(Id("resp"), Nil()), } @@ -72,7 +72,7 @@ func jCallRequestWithResp(recv, methodName, respType string) []Code { func jCallNotification(recv, methodName string) []Code { return []Code{ - If(List(Id("err")).Op(":=").Id(recv).Dot(methodName).Call(Id("p")), Id("err").Op("!=").Nil()).Block(jRetToReqErr()), + If(List(Id("err")).Op(":=").Id(recv).Dot(methodName).Call(Id("ctx"), Id("p")), Id("err").Op("!=").Nil()).Block(jRetToReqErr()), Return(Nil(), Nil()), } } diff --git a/go/cmd/generate/internal/emit/types.go b/go/cmd/generate/internal/emit/types.go index d5bdc53..fce7502 100644 --- a/go/cmd/generate/internal/emit/types.go +++ b/go/cmd/generate/internal/emit/types.go @@ -132,14 +132,14 @@ func WriteTypesJen(outDir string, schema *load.Schema, meta *load.Meta) error { } if mi.Notif != "" { name := ir.DispatchMethodNameForNotification(k, mi.Notif) - *target = append(*target, Id(name).Params(Id("params").Id(mi.Notif)).Error()) + *target = append(*target, Id(name).Params(Id("ctx").Qual("context", "Context"), Id("params").Id(mi.Notif)).Error()) } else if mi.Req != "" { respName := strings.TrimSuffix(mi.Req, "Request") + "Response" methodName := strings.TrimSuffix(mi.Req, "Request") if ir.IsNullResponse(schema.Defs[respName]) { - *target = append(*target, Id(methodName).Params(Id("params").Id(mi.Req)).Error()) + *target = append(*target, Id(methodName).Params(Id("ctx").Qual("context", "Context"), Id("params").Id(mi.Req)).Error()) } else { - *target = append(*target, Id(methodName).Params(Id("params").Id(mi.Req)).Params(Id(respName), Error())) + *target = append(*target, Id(methodName).Params(Id("ctx").Qual("context", "Context"), Id("params").Id(mi.Req)).Params(Id(respName), Error())) } } } @@ -172,14 +172,14 @@ func WriteTypesJen(outDir string, schema *load.Schema, meta *load.Meta) error { } if mi.Notif != "" { name := ir.DispatchMethodNameForNotification(k, mi.Notif) - *target = append(*target, Id(name).Params(Id("params").Id(mi.Notif)).Error()) + *target = append(*target, Id(name).Params(Id("ctx").Qual("context", "Context"), Id("params").Id(mi.Notif)).Error()) } else if mi.Req != "" { respName := strings.TrimSuffix(mi.Req, "Request") + "Response" methodName := strings.TrimSuffix(mi.Req, "Request") if ir.IsNullResponse(schema.Defs[respName]) { - *target = append(*target, Id(methodName).Params(Id("params").Id(mi.Req)).Error()) + *target = append(*target, Id(methodName).Params(Id("ctx").Qual("context", "Context"), Id("params").Id(mi.Req)).Error()) } else { - *target = append(*target, Id(methodName).Params(Id("params").Id(mi.Req)).Params(Id(respName), Error())) + *target = append(*target, Id(methodName).Params(Id("ctx").Qual("context", "Context"), Id("params").Id(mi.Req)).Params(Id(respName), Error())) } } } diff --git a/go/connection.go b/go/connection.go index f6d7e66..1898b71 100644 --- a/go/connection.go +++ b/go/connection.go @@ -2,6 +2,7 @@ package acp import ( "bufio" + "context" "encoding/json" "io" "sync" @@ -21,7 +22,7 @@ type pendingResponse struct { ch chan anyMessage } -type MethodHandler func(method string, params json.RawMessage) (any, *RequestError) +type MethodHandler func(ctx context.Context, method string, params json.RawMessage) (any, *RequestError) // Connection is a simple JSON-RPC 2.0 connection over line-delimited JSON. type Connection struct { @@ -92,6 +93,12 @@ func (c *Connection) receive() { } func (c *Connection) handleInbound(req *anyMessage) { + // Context that cancels when the connection is closed + ctx, cancel := context.WithCancel(context.Background()) + go func() { + <-c.Done() + cancel() + }() res := anyMessage{JSONRPC: "2.0"} // copy ID if present if req.ID != nil { @@ -105,7 +112,7 @@ func (c *Connection) handleInbound(req *anyMessage) { return } - result, err := c.handler(req.Method, req.Params) + result, err := c.handler(ctx, req.Method, req.Params) if req.ID == nil { // notification: nothing to send return @@ -139,7 +146,7 @@ func (c *Connection) sendMessage(msg anyMessage) error { // SendRequest sends a JSON-RPC request and returns a typed result. // For methods that do not return a result, use SendRequestNoResult instead. -func SendRequest[T any](c *Connection, method string, params any) (T, error) { +func SendRequest[T any](c *Connection, ctx context.Context, method string, params any) (T, error) { var zero T // allocate id id := c.nextID.Add(1) @@ -169,6 +176,12 @@ func SendRequest[T any](c *Connection, method string, params any) (T, error) { d := c.Done() select { case resp = <-pr.ch: + case <-ctx.Done(): + // best-effort cleanup + c.mu.Lock() + delete(c.pending, idKey) + c.mu.Unlock() + return zero, NewInternalError(map[string]any{"error": ctx.Err().Error()}) case <-d: return zero, NewInternalError(map[string]any{"error": "peer disconnected before response"}) } @@ -185,7 +198,7 @@ func SendRequest[T any](c *Connection, method string, params any) (T, error) { } // SendRequestNoResult sends a JSON-RPC request that returns no result payload. -func (c *Connection) SendRequestNoResult(method string, params any) error { +func (c *Connection) SendRequestNoResult(ctx context.Context, method string, params any) error { // allocate id id := c.nextID.Add(1) idRaw, _ := json.Marshal(id) @@ -213,6 +226,11 @@ func (c *Connection) SendRequestNoResult(method string, params any) error { d := c.Done() select { case resp = <-pr.ch: + case <-ctx.Done(): + c.mu.Lock() + delete(c.pending, idKey) + c.mu.Unlock() + return NewInternalError(map[string]any{"error": ctx.Err().Error()}) case <-d: return NewInternalError(map[string]any{"error": "peer disconnected before response"}) } @@ -222,7 +240,12 @@ func (c *Connection) SendRequestNoResult(method string, params any) error { return nil } -func (c *Connection) SendNotification(method string, params any) error { +func (c *Connection) SendNotification(ctx context.Context, method string, params any) error { + select { + case <-ctx.Done(): + return NewInternalError(map[string]any{"error": ctx.Err().Error()}) + default: + } msg := anyMessage{JSONRPC: "2.0", Method: method} if params != nil { b, err := json.Marshal(params) diff --git a/go/example/agent/main.go b/go/example/agent/main.go index a046732..6942f4a 100644 --- a/go/example/agent/main.go +++ b/go/example/agent/main.go @@ -33,7 +33,7 @@ func newExampleAgent() *exampleAgent { // Implement acp.AgentConnAware to receive the connection after construction. func (a *exampleAgent) SetAgentConnection(conn *acp.AgentSideConnection) { a.conn = conn } -func (a *exampleAgent) Initialize(params acp.InitializeRequest) (acp.InitializeResponse, error) { +func (a *exampleAgent) Initialize(ctx context.Context, params acp.InitializeRequest) (acp.InitializeResponse, error) { return acp.InitializeResponse{ ProtocolVersion: acp.ProtocolVersionNumber, AgentCapabilities: acp.AgentCapabilities{ @@ -42,17 +42,17 @@ func (a *exampleAgent) Initialize(params acp.InitializeRequest) (acp.InitializeR }, nil } -func (a *exampleAgent) NewSession(params acp.NewSessionRequest) (acp.NewSessionResponse, error) { +func (a *exampleAgent) NewSession(ctx context.Context, params acp.NewSessionRequest) (acp.NewSessionResponse, error) { sid := randomID() a.sessions[sid] = &agentSession{} return acp.NewSessionResponse{SessionId: acp.SessionId(sid)}, nil } -func (a *exampleAgent) Authenticate(_ acp.AuthenticateRequest) error { return nil } +func (a *exampleAgent) Authenticate(ctx context.Context, _ acp.AuthenticateRequest) error { return nil } -func (a *exampleAgent) LoadSession(_ acp.LoadSessionRequest) error { return nil } +func (a *exampleAgent) LoadSession(ctx context.Context, _ acp.LoadSessionRequest) error { return nil } -func (a *exampleAgent) Cancel(params acp.CancelNotification) error { +func (a *exampleAgent) Cancel(ctx context.Context, params acp.CancelNotification) error { if s, ok := a.sessions[string(params.SessionId)]; ok { if s.cancel != nil { s.cancel() @@ -61,7 +61,7 @@ func (a *exampleAgent) Cancel(params acp.CancelNotification) error { return nil } -func (a *exampleAgent) Prompt(params acp.PromptRequest) (acp.PromptResponse, error) { +func (a *exampleAgent) Prompt(ctx context.Context, params acp.PromptRequest) (acp.PromptResponse, error) { sid := string(params.SessionId) s, ok := a.sessions[sid] if !ok { @@ -88,7 +88,7 @@ func (a *exampleAgent) Prompt(params acp.PromptRequest) (acp.PromptResponse, err func (a *exampleAgent) simulateTurn(ctx context.Context, sid string) error { // disclaimer: stream a demo notice so clients see it's the example agent - if err := a.conn.SessionUpdate(acp.SessionNotification{ + if err := a.conn.SessionUpdate(ctx, acp.SessionNotification{ SessionId: acp.SessionId(sid), Update: acp.SessionUpdate{ AgentMessageChunk: &acp.SessionUpdateAgentMessageChunk{ @@ -102,7 +102,7 @@ func (a *exampleAgent) simulateTurn(ctx context.Context, sid string) error { return err } // initial message chunk - if err := a.conn.SessionUpdate(acp.SessionNotification{ + if err := a.conn.SessionUpdate(ctx, acp.SessionNotification{ SessionId: acp.SessionId(sid), Update: acp.SessionUpdate{ AgentMessageChunk: &acp.SessionUpdateAgentMessageChunk{ @@ -117,7 +117,7 @@ func (a *exampleAgent) simulateTurn(ctx context.Context, sid string) error { } // tool call without permission - if err := a.conn.SessionUpdate(acp.SessionNotification{ + if err := a.conn.SessionUpdate(ctx, acp.SessionNotification{ SessionId: acp.SessionId(sid), Update: acp.SessionUpdate{ToolCall: &acp.SessionUpdateToolCall{ ToolCallId: acp.ToolCallId("call_1"), @@ -135,7 +135,7 @@ func (a *exampleAgent) simulateTurn(ctx context.Context, sid string) error { } // update tool call completed - if err := a.conn.SessionUpdate(acp.SessionNotification{ + if err := a.conn.SessionUpdate(ctx, acp.SessionNotification{ SessionId: acp.SessionId(sid), Update: acp.SessionUpdate{ToolCallUpdate: &acp.SessionUpdateToolCallUpdate{ ToolCallId: acp.ToolCallId("call_1"), @@ -154,7 +154,7 @@ func (a *exampleAgent) simulateTurn(ctx context.Context, sid string) error { } // more text - if err := a.conn.SessionUpdate(acp.SessionNotification{ + if err := a.conn.SessionUpdate(ctx, acp.SessionNotification{ SessionId: acp.SessionId(sid), Update: acp.SessionUpdate{ AgentMessageChunk: &acp.SessionUpdateAgentMessageChunk{ @@ -169,7 +169,7 @@ func (a *exampleAgent) simulateTurn(ctx context.Context, sid string) error { } // tool call requiring permission - if err := a.conn.SessionUpdate(acp.SessionNotification{ + if err := a.conn.SessionUpdate(ctx, acp.SessionNotification{ SessionId: acp.SessionId(sid), Update: acp.SessionUpdate{ToolCall: &acp.SessionUpdateToolCall{ ToolCallId: acp.ToolCallId("call_2"), @@ -184,7 +184,7 @@ func (a *exampleAgent) simulateTurn(ctx context.Context, sid string) error { } // request permission for sensitive operation - permResp, err := a.conn.RequestPermission(acp.RequestPermissionRequest{ + permResp, err := a.conn.RequestPermission(ctx, acp.RequestPermissionRequest{ SessionId: acp.SessionId(sid), ToolCall: acp.ToolCallUpdate{ ToolCallId: acp.ToolCallId("call_2"), @@ -212,7 +212,7 @@ func (a *exampleAgent) simulateTurn(ctx context.Context, sid string) error { } switch string(permResp.Outcome.Selected.OptionId) { case "allow": - if err := a.conn.SessionUpdate(acp.SessionNotification{ + if err := a.conn.SessionUpdate(ctx, acp.SessionNotification{ SessionId: acp.SessionId(sid), Update: acp.SessionUpdate{ToolCallUpdate: &acp.SessionUpdateToolCallUpdate{ ToolCallId: acp.ToolCallId("call_2"), @@ -225,7 +225,7 @@ func (a *exampleAgent) simulateTurn(ctx context.Context, sid string) error { if err := pause(ctx, time.Second); err != nil { return err } - if err := a.conn.SessionUpdate(acp.SessionNotification{ + if err := a.conn.SessionUpdate(ctx, acp.SessionNotification{ SessionId: acp.SessionId(sid), Update: acp.SessionUpdate{ AgentMessageChunk: &acp.SessionUpdateAgentMessageChunk{ @@ -239,7 +239,7 @@ func (a *exampleAgent) simulateTurn(ctx context.Context, sid string) error { if err := pause(ctx, time.Second); err != nil { return err } - if err := a.conn.SessionUpdate(acp.SessionNotification{ + if err := a.conn.SessionUpdate(ctx, acp.SessionNotification{ SessionId: acp.SessionId(sid), Update: acp.SessionUpdate{ AgentMessageChunk: &acp.SessionUpdateAgentMessageChunk{ diff --git a/go/example/claude-code/main.go b/go/example/claude-code/main.go index a6ebdbe..c7edb41 100644 --- a/go/example/claude-code/main.go +++ b/go/example/claude-code/main.go @@ -2,6 +2,7 @@ package main import ( "bufio" + "context" "encoding/json" "flag" "fmt" @@ -22,7 +23,7 @@ type replClient struct { var _ acp.Client = (*replClient)(nil) -func (c *replClient) RequestPermission(params acp.RequestPermissionRequest) (acp.RequestPermissionResponse, error) { +func (c *replClient) RequestPermission(ctx context.Context, params acp.RequestPermissionRequest) (acp.RequestPermissionResponse, error) { if c.autoApprove { // Prefer an allow option if present; otherwise choose the first option. for _, o := range params.Options { @@ -59,7 +60,7 @@ func (c *replClient) RequestPermission(params acp.RequestPermissionRequest) (acp } } -func (c *replClient) SessionUpdate(params acp.SessionNotification) error { +func (c *replClient) SessionUpdate(ctx context.Context, params acp.SessionNotification) error { u := params.Update switch { case u.AgentMessageChunk != nil: @@ -88,7 +89,7 @@ func (c *replClient) SessionUpdate(params acp.SessionNotification) error { return nil } -func (c *replClient) WriteTextFile(params acp.WriteTextFileRequest) error { +func (c *replClient) WriteTextFile(ctx context.Context, params acp.WriteTextFileRequest) error { if !filepath.IsAbs(params.Path) { return fmt.Errorf("path must be absolute: %s", params.Path) } @@ -105,7 +106,7 @@ func (c *replClient) WriteTextFile(params acp.WriteTextFileRequest) error { return nil } -func (c *replClient) ReadTextFile(params acp.ReadTextFileRequest) (acp.ReadTextFileResponse, error) { +func (c *replClient) ReadTextFile(ctx context.Context, params acp.ReadTextFileRequest) (acp.ReadTextFileResponse, error) { if !filepath.IsAbs(params.Path) { return acp.ReadTextFileResponse{}, fmt.Errorf("path must be absolute: %s", params.Path) } @@ -133,22 +134,22 @@ func (c *replClient) ReadTextFile(params acp.ReadTextFileRequest) (acp.ReadTextF } // Optional/UNSTABLE terminal methods: implement as no-ops for example -func (c *replClient) CreateTerminal(params acp.CreateTerminalRequest) (acp.CreateTerminalResponse, error) { +func (c *replClient) CreateTerminal(ctx context.Context, params acp.CreateTerminalRequest) (acp.CreateTerminalResponse, error) { fmt.Printf("[Client] CreateTerminal: %v\n", params) return acp.CreateTerminalResponse{TerminalId: "term-1"}, nil } -func (c *replClient) TerminalOutput(params acp.TerminalOutputRequest) (acp.TerminalOutputResponse, error) { +func (c *replClient) TerminalOutput(ctx context.Context, params acp.TerminalOutputRequest) (acp.TerminalOutputResponse, error) { fmt.Printf("[Client] TerminalOutput: %v\n", params) return acp.TerminalOutputResponse{Output: "", Truncated: false}, nil } -func (c *replClient) ReleaseTerminal(params acp.ReleaseTerminalRequest) error { +func (c *replClient) ReleaseTerminal(ctx context.Context, params acp.ReleaseTerminalRequest) error { fmt.Printf("[Client] ReleaseTerminal: %v\n", params) return nil } -func (c *replClient) WaitForTerminalExit(params acp.WaitForTerminalExitRequest) (acp.WaitForTerminalExitResponse, error) { +func (c *replClient) WaitForTerminalExit(ctx context.Context, params acp.WaitForTerminalExitRequest) (acp.WaitForTerminalExitResponse, error) { fmt.Printf("[Client] WaitForTerminalExit: %v\n", params) return acp.WaitForTerminalExitResponse{}, nil } @@ -158,7 +159,9 @@ func main() { flag.Parse() // Invoke Claude Code via npx - cmd := exec.Command("npx", "-y", "@zed-industries/claude-code-acp") + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + cmd := exec.CommandContext(ctx, "npx", "-y", "@zed-industries/claude-code-acp") cmd.Stderr = os.Stderr stdin, err := cmd.StdinPipe() if err != nil { @@ -180,7 +183,7 @@ func main() { conn := acp.NewClientSideConnection(client, stdin, stdout) // Initialize - initResp, err := conn.Initialize(acp.InitializeRequest{ + initResp, err := conn.Initialize(ctx, acp.InitializeRequest{ ProtocolVersion: acp.ProtocolVersionNumber, ClientCapabilities: acp.ClientCapabilities{Fs: acp.FileSystemCapability{ReadTextFile: true, WriteTextFile: true}}, }) @@ -200,7 +203,7 @@ func main() { fmt.Printf("✅ Connected to Claude Code (protocol v%v)\n", initResp.ProtocolVersion) // New session - newSess, err := conn.NewSession(acp.NewSessionRequest{Cwd: mustCwd(), McpServers: []acp.McpServer{}}) + newSess, err := conn.NewSession(ctx, acp.NewSessionRequest{Cwd: mustCwd(), McpServers: []acp.McpServer{}}) if err != nil { if re, ok := err.(*acp.RequestError); ok { if b, mErr := json.MarshalIndent(re, "", " "); mErr == nil { @@ -229,14 +232,14 @@ func main() { } switch line { case ":exit", ":quit": - _ = cmd.Process.Kill() + cancel() return case ":cancel": - _ = conn.Cancel(acp.CancelNotification(newSess)) + _ = conn.Cancel(ctx, acp.CancelNotification(newSess)) continue } // Send prompt and wait for completion while streaming updates are printed via SessionUpdate - if _, err := conn.Prompt(acp.PromptRequest{ + if _, err := conn.Prompt(ctx, acp.PromptRequest{ SessionId: newSess.SessionId, Prompt: []acp.ContentBlock{acp.TextBlock(line)}, }); err != nil { diff --git a/go/example/client/main.go b/go/example/client/main.go index 5ca0d40..670821e 100644 --- a/go/example/client/main.go +++ b/go/example/client/main.go @@ -2,6 +2,7 @@ package main import ( "bufio" + "context" "encoding/json" "fmt" "os" @@ -19,7 +20,7 @@ var ( _ acp.ClientTerminal = (*exampleClient)(nil) ) -func (e *exampleClient) RequestPermission(params acp.RequestPermissionRequest) (acp.RequestPermissionResponse, error) { +func (e *exampleClient) RequestPermission(ctx context.Context, params acp.RequestPermissionRequest) (acp.RequestPermissionResponse, error) { fmt.Printf("\n🔐 Permission requested: %s\n", params.ToolCall.Title) fmt.Println("\nOptions:") for i, opt := range params.Options { @@ -43,7 +44,7 @@ func (e *exampleClient) RequestPermission(params acp.RequestPermissionRequest) ( } } -func (e *exampleClient) SessionUpdate(params acp.SessionNotification) error { +func (e *exampleClient) SessionUpdate(ctx context.Context, params acp.SessionNotification) error { u := params.Update switch { case u.AgentMessageChunk != nil: @@ -83,7 +84,7 @@ func displayUpdateKind(u acp.SessionUpdate) string { } } -func (e *exampleClient) WriteTextFile(params acp.WriteTextFileRequest) error { +func (e *exampleClient) WriteTextFile(ctx context.Context, params acp.WriteTextFileRequest) error { if !filepath.IsAbs(params.Path) { return fmt.Errorf("path must be absolute: %s", params.Path) } @@ -100,7 +101,7 @@ func (e *exampleClient) WriteTextFile(params acp.WriteTextFileRequest) error { return nil } -func (e *exampleClient) ReadTextFile(params acp.ReadTextFileRequest) (acp.ReadTextFileResponse, error) { +func (e *exampleClient) ReadTextFile(ctx context.Context, params acp.ReadTextFileRequest) (acp.ReadTextFileResponse, error) { if !filepath.IsAbs(params.Path) { return acp.ReadTextFileResponse{}, fmt.Errorf("path must be absolute: %s", params.Path) } @@ -129,34 +130,37 @@ func (e *exampleClient) ReadTextFile(params acp.ReadTextFileRequest) (acp.ReadTe } // Optional/UNSTABLE terminal methods: implement as no-ops for example -func (e *exampleClient) CreateTerminal(params acp.CreateTerminalRequest) (acp.CreateTerminalResponse, error) { +func (e *exampleClient) CreateTerminal(ctx context.Context, params acp.CreateTerminalRequest) (acp.CreateTerminalResponse, error) { fmt.Printf("[Client] CreateTerminal: %v\n", params) return acp.CreateTerminalResponse{TerminalId: "term-1"}, nil } -func (e *exampleClient) TerminalOutput(params acp.TerminalOutputRequest) (acp.TerminalOutputResponse, error) { +func (e *exampleClient) TerminalOutput(ctx context.Context, params acp.TerminalOutputRequest) (acp.TerminalOutputResponse, error) { fmt.Printf("[Client] TerminalOutput: %v\n", params) return acp.TerminalOutputResponse{Output: "", Truncated: false}, nil } -func (e *exampleClient) ReleaseTerminal(params acp.ReleaseTerminalRequest) error { +func (e *exampleClient) ReleaseTerminal(ctx context.Context, params acp.ReleaseTerminalRequest) error { fmt.Printf("[Client] ReleaseTerminal: %v\n", params) return nil } -func (e *exampleClient) WaitForTerminalExit(params acp.WaitForTerminalExitRequest) (acp.WaitForTerminalExitResponse, error) { +func (e *exampleClient) WaitForTerminalExit(ctx context.Context, params acp.WaitForTerminalExitRequest) (acp.WaitForTerminalExitResponse, error) { fmt.Printf("[Client] WaitForTerminalExit: %v\n", params) return acp.WaitForTerminalExitResponse{}, nil } func main() { // If args provided, treat them as agent program + args. Otherwise run the Go agent example. + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + var cmd *exec.Cmd if len(os.Args) > 1 { - cmd = exec.Command(os.Args[1], os.Args[2:]...) + cmd = exec.CommandContext(ctx, os.Args[1], os.Args[2:]...) } else { // Assumes running from the go/ directory; if not, adjust path accordingly. - cmd = exec.Command("go", "run", "./example/agent") + cmd = exec.CommandContext(ctx, "go", "run", "./example/agent") } cmd.Stderr = os.Stderr cmd.Stdout = nil @@ -173,7 +177,7 @@ func main() { conn := acp.NewClientSideConnection(client, stdin, stdout) // Initialize - initResp, err := conn.Initialize(acp.InitializeRequest{ + initResp, err := conn.Initialize(ctx, acp.InitializeRequest{ ProtocolVersion: acp.ProtocolVersionNumber, ClientCapabilities: acp.ClientCapabilities{ Fs: acp.FileSystemCapability{ReadTextFile: true, WriteTextFile: true}, @@ -196,7 +200,7 @@ func main() { fmt.Printf("✅ Connected to agent (protocol v%v)\n", initResp.ProtocolVersion) // New session - newSess, err := conn.NewSession(acp.NewSessionRequest{Cwd: mustCwd(), McpServers: []acp.McpServer{}}) + newSess, err := conn.NewSession(ctx, acp.NewSessionRequest{Cwd: mustCwd(), McpServers: []acp.McpServer{}}) if err != nil { if re, ok := err.(*acp.RequestError); ok { if b, mErr := json.MarshalIndent(re, "", " "); mErr == nil { @@ -215,7 +219,7 @@ func main() { fmt.Print(" ") // Send prompt - if _, err := conn.Prompt(acp.PromptRequest{ + if _, err := conn.Prompt(ctx, acp.PromptRequest{ SessionId: newSess.SessionId, Prompt: []acp.ContentBlock{acp.TextBlock("Hello, agent!")}, }); err != nil { diff --git a/go/example/gemini/main.go b/go/example/gemini/main.go index adbacb6..cf317d7 100644 --- a/go/example/gemini/main.go +++ b/go/example/gemini/main.go @@ -2,6 +2,7 @@ package main import ( "bufio" + "context" "encoding/json" "flag" "fmt" @@ -25,7 +26,7 @@ var ( _ acp.ClientTerminal = (*replClient)(nil) ) -func (c *replClient) RequestPermission(params acp.RequestPermissionRequest) (acp.RequestPermissionResponse, error) { +func (c *replClient) RequestPermission(ctx context.Context, params acp.RequestPermissionRequest) (acp.RequestPermissionResponse, error) { if c.autoApprove { // Prefer an allow option if present; otherwise choose the first option. for _, o := range params.Options { @@ -62,7 +63,7 @@ func (c *replClient) RequestPermission(params acp.RequestPermissionRequest) (acp } } -func (c *replClient) SessionUpdate(params acp.SessionNotification) error { +func (c *replClient) SessionUpdate(ctx context.Context, params acp.SessionNotification) error { u := params.Update switch { case u.AgentMessageChunk != nil: @@ -91,7 +92,7 @@ func (c *replClient) SessionUpdate(params acp.SessionNotification) error { return nil } -func (c *replClient) WriteTextFile(params acp.WriteTextFileRequest) error { +func (c *replClient) WriteTextFile(ctx context.Context, params acp.WriteTextFileRequest) error { if !filepath.IsAbs(params.Path) { return fmt.Errorf("path must be absolute: %s", params.Path) } @@ -108,7 +109,7 @@ func (c *replClient) WriteTextFile(params acp.WriteTextFileRequest) error { return nil } -func (c *replClient) ReadTextFile(params acp.ReadTextFileRequest) (acp.ReadTextFileResponse, error) { +func (c *replClient) ReadTextFile(ctx context.Context, params acp.ReadTextFileRequest) (acp.ReadTextFileResponse, error) { if !filepath.IsAbs(params.Path) { return acp.ReadTextFileResponse{}, fmt.Errorf("path must be absolute: %s", params.Path) } @@ -136,22 +137,22 @@ func (c *replClient) ReadTextFile(params acp.ReadTextFileRequest) (acp.ReadTextF } // Optional/UNSTABLE terminal methods: implement as no-ops for example -func (c *replClient) CreateTerminal(params acp.CreateTerminalRequest) (acp.CreateTerminalResponse, error) { +func (c *replClient) CreateTerminal(ctx context.Context, params acp.CreateTerminalRequest) (acp.CreateTerminalResponse, error) { fmt.Printf("[Client] CreateTerminal: %v\n", params) return acp.CreateTerminalResponse{TerminalId: "term-1"}, nil } -func (c *replClient) TerminalOutput(params acp.TerminalOutputRequest) (acp.TerminalOutputResponse, error) { +func (c *replClient) TerminalOutput(ctx context.Context, params acp.TerminalOutputRequest) (acp.TerminalOutputResponse, error) { fmt.Printf("[Client] TerminalOutput: %v\n", params) return acp.TerminalOutputResponse{Output: "", Truncated: false}, nil } -func (c *replClient) ReleaseTerminal(params acp.ReleaseTerminalRequest) error { +func (c *replClient) ReleaseTerminal(ctx context.Context, params acp.ReleaseTerminalRequest) error { fmt.Printf("[Client] ReleaseTerminal: %v\n", params) return nil } -func (c *replClient) WaitForTerminalExit(params acp.WaitForTerminalExitRequest) (acp.WaitForTerminalExitResponse, error) { +func (c *replClient) WaitForTerminalExit(ctx context.Context, params acp.WaitForTerminalExitRequest) (acp.WaitForTerminalExitResponse, error) { fmt.Printf("[Client] WaitForTerminalExit: %v\n", params) return acp.WaitForTerminalExitResponse{}, nil } @@ -175,7 +176,10 @@ func main() { args = append(args, "--debug") } - cmd := exec.Command(*binary, args...) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + cmd := exec.CommandContext(ctx, *binary, args...) cmd.Stderr = os.Stderr stdin, err := cmd.StdinPipe() if err != nil { @@ -197,7 +201,7 @@ func main() { conn := acp.NewClientSideConnection(client, stdin, stdout) // Initialize - initResp, err := conn.Initialize(acp.InitializeRequest{ + initResp, err := conn.Initialize(ctx, acp.InitializeRequest{ ProtocolVersion: acp.ProtocolVersionNumber, ClientCapabilities: acp.ClientCapabilities{ Fs: acp.FileSystemCapability{ReadTextFile: true, WriteTextFile: true}, @@ -220,7 +224,10 @@ func main() { fmt.Printf("✅ Connected to Gemini (protocol v%v)\n", initResp.ProtocolVersion) // New session - newSess, err := conn.NewSession(acp.NewSessionRequest{Cwd: mustCwd(), McpServers: []acp.McpServer{}}) + newSess, err := conn.NewSession(ctx, acp.NewSessionRequest{ + Cwd: mustCwd(), + McpServers: []acp.McpServer{}, + }) if err != nil { if re, ok := err.(*acp.RequestError); ok { if b, mErr := json.MarshalIndent(re, "", " "); mErr == nil { @@ -249,14 +256,14 @@ func main() { } switch line { case ":exit", ":quit": - _ = cmd.Process.Kill() + cancel() return case ":cancel": - _ = conn.Cancel(acp.CancelNotification(newSess)) + _ = conn.Cancel(ctx, acp.CancelNotification(newSess)) continue } // Send prompt and wait for completion while streaming updates are printed via SessionUpdate - if _, err := conn.Prompt(acp.PromptRequest{ + if _, err := conn.Prompt(ctx, acp.PromptRequest{ SessionId: newSess.SessionId, Prompt: []acp.ContentBlock{acp.TextBlock(line)}, }); err != nil { diff --git a/go/example_agent_test.go b/go/example_agent_test.go index 0f28bc9..8c5fa1b 100644 --- a/go/example_agent_test.go +++ b/go/example_agent_test.go @@ -1,6 +1,7 @@ package acp import ( + "context" "os" ) @@ -11,21 +12,21 @@ type agentExample struct{ conn *AgentSideConnection } func (a *agentExample) SetAgentConnection(c *AgentSideConnection) { a.conn = c } -func (agentExample) Authenticate(AuthenticateRequest) error { return nil } -func (agentExample) Initialize(InitializeRequest) (InitializeResponse, error) { +func (agentExample) Authenticate(ctx context.Context, _ AuthenticateRequest) error { return nil } +func (agentExample) Initialize(ctx context.Context, _ InitializeRequest) (InitializeResponse, error) { return InitializeResponse{ ProtocolVersion: ProtocolVersionNumber, AgentCapabilities: AgentCapabilities{LoadSession: false}, }, nil } -func (agentExample) Cancel(CancelNotification) error { return nil } -func (agentExample) NewSession(NewSessionRequest) (NewSessionResponse, error) { +func (agentExample) Cancel(ctx context.Context, _ CancelNotification) error { return nil } +func (agentExample) NewSession(ctx context.Context, _ NewSessionRequest) (NewSessionResponse, error) { return NewSessionResponse{SessionId: SessionId("sess_demo")}, nil } -func (a *agentExample) Prompt(p PromptRequest) (PromptResponse, error) { +func (a *agentExample) Prompt(ctx context.Context, p PromptRequest) (PromptResponse, error) { // Stream an initial agent message. - _ = a.conn.SessionUpdate(SessionNotification{ + _ = a.conn.SessionUpdate(ctx, SessionNotification{ SessionId: p.SessionId, Update: SessionUpdate{ AgentMessageChunk: &SessionUpdateAgentMessageChunk{ @@ -35,7 +36,7 @@ func (a *agentExample) Prompt(p PromptRequest) (PromptResponse, error) { }) // Announce a tool call. - _ = a.conn.SessionUpdate(SessionNotification{ + _ = a.conn.SessionUpdate(ctx, SessionNotification{ SessionId: p.SessionId, Update: SessionUpdate{ToolCall: &SessionUpdateToolCall{ ToolCallId: ToolCallId("call_1"), @@ -48,7 +49,7 @@ func (a *agentExample) Prompt(p PromptRequest) (PromptResponse, error) { }) // Ask the client for permission to proceed with the change. - resp, _ := a.conn.RequestPermission(RequestPermissionRequest{ + resp, _ := a.conn.RequestPermission(ctx, RequestPermissionRequest{ SessionId: p.SessionId, ToolCall: ToolCallUpdate{ ToolCallId: ToolCallId("call_1"), @@ -66,7 +67,7 @@ func (a *agentExample) Prompt(p PromptRequest) (PromptResponse, error) { if resp.Outcome.Selected != nil && string(resp.Outcome.Selected.OptionId) == "allow" { // Mark tool call completed and stream a final message. - _ = a.conn.SessionUpdate(SessionNotification{ + _ = a.conn.SessionUpdate(ctx, SessionNotification{ SessionId: p.SessionId, Update: SessionUpdate{ToolCallUpdate: &SessionUpdateToolCallUpdate{ ToolCallId: ToolCallId("call_1"), @@ -74,7 +75,7 @@ func (a *agentExample) Prompt(p PromptRequest) (PromptResponse, error) { RawOutput: map[string]any{"success": true}, }}, }) - _ = a.conn.SessionUpdate(SessionNotification{ + _ = a.conn.SessionUpdate(ctx, SessionNotification{ SessionId: p.SessionId, Update: SessionUpdate{AgentMessageChunk: &SessionUpdateAgentMessageChunk{Content: TextBlock("Done.")}}, }) diff --git a/go/example_client_test.go b/go/example_client_test.go index 5fe56f5..c4fad1c 100644 --- a/go/example_client_test.go +++ b/go/example_client_test.go @@ -1,6 +1,7 @@ package acp import ( + "context" "fmt" "os" "os/exec" @@ -13,7 +14,7 @@ import ( // permission option. type clientExample struct{} -func (clientExample) RequestPermission(p RequestPermissionRequest) (RequestPermissionResponse, error) { +func (clientExample) RequestPermission(ctx context.Context, p RequestPermissionRequest) (RequestPermissionResponse, error) { if len(p.Options) == 0 { return RequestPermissionResponse{ Outcome: RequestPermissionOutcome{ @@ -28,7 +29,7 @@ func (clientExample) RequestPermission(p RequestPermissionRequest) (RequestPermi }, nil } -func (clientExample) SessionUpdate(n SessionNotification) error { +func (clientExample) SessionUpdate(ctx context.Context, n SessionNotification) error { u := n.Update switch { case u.AgentMessageChunk != nil: @@ -46,7 +47,7 @@ func (clientExample) SessionUpdate(n SessionNotification) error { return nil } -func (clientExample) WriteTextFile(p WriteTextFileRequest) error { +func (clientExample) WriteTextFile(ctx context.Context, p WriteTextFileRequest) error { if !filepath.IsAbs(p.Path) { return fmt.Errorf("path must be absolute: %s", p.Path) } @@ -56,7 +57,7 @@ func (clientExample) WriteTextFile(p WriteTextFileRequest) error { return os.WriteFile(p.Path, []byte(p.Content), 0o644) } -func (clientExample) ReadTextFile(p ReadTextFileRequest) (ReadTextFileResponse, error) { +func (clientExample) ReadTextFile(ctx context.Context, p ReadTextFileRequest) (ReadTextFileResponse, error) { if !filepath.IsAbs(p.Path) { return ReadTextFileResponse{}, fmt.Errorf("path must be absolute: %s", p.Path) } @@ -88,13 +89,14 @@ func (clientExample) ReadTextFile(p ReadTextFileRequest) (ReadTextFileResponse, // Example_client launches the Go agent example, negotiates protocol, // opens a session, and sends a simple prompt. func Example_client() { + ctx := context.Background() cmd := exec.Command("go", "run", "./example/agent") stdin, _ := cmd.StdinPipe() stdout, _ := cmd.StdoutPipe() _ = cmd.Start() conn := NewClientSideConnection(clientExample{}, stdin, stdout) - _, _ = conn.Initialize(InitializeRequest{ + _, _ = conn.Initialize(ctx, InitializeRequest{ ProtocolVersion: ProtocolVersionNumber, ClientCapabilities: ClientCapabilities{ Fs: FileSystemCapability{ @@ -104,11 +106,11 @@ func Example_client() { Terminal: true, }, }) - sess, _ := conn.NewSession(NewSessionRequest{ + sess, _ := conn.NewSession(ctx, NewSessionRequest{ Cwd: "/", McpServers: []McpServer{}, }) - _, _ = conn.Prompt(PromptRequest{ + _, _ = conn.Prompt(ctx, PromptRequest{ SessionId: sess.SessionId, Prompt: []ContentBlock{TextBlock("Hello, agent!")}, }) diff --git a/go/example_gemini_test.go b/go/example_gemini_test.go index 8487c3f..4a7d1c9 100644 --- a/go/example_gemini_test.go +++ b/go/example_gemini_test.go @@ -1,6 +1,7 @@ package acp import ( + "context" "fmt" "os/exec" ) @@ -9,14 +10,14 @@ import ( // selects the first permission option. File ops are no-ops here. type geminiClient struct{} -func (geminiClient) RequestPermission(p RequestPermissionRequest) (RequestPermissionResponse, error) { +func (geminiClient) RequestPermission(ctx context.Context, p RequestPermissionRequest) (RequestPermissionResponse, error) { if len(p.Options) == 0 { return RequestPermissionResponse{Outcome: RequestPermissionOutcome{Cancelled: &RequestPermissionOutcomeCancelled{}}}, nil } return RequestPermissionResponse{Outcome: RequestPermissionOutcome{Selected: &RequestPermissionOutcomeSelected{OptionId: p.Options[0].OptionId}}}, nil } -func (geminiClient) SessionUpdate(n SessionNotification) error { +func (geminiClient) SessionUpdate(ctx context.Context, n SessionNotification) error { if n.Update.AgentMessageChunk != nil { c := n.Update.AgentMessageChunk.Content if c.Type == "text" && c.Text != nil { @@ -26,21 +27,22 @@ func (geminiClient) SessionUpdate(n SessionNotification) error { return nil } -func (geminiClient) ReadTextFile(ReadTextFileRequest) (ReadTextFileResponse, error) { +func (geminiClient) ReadTextFile(ctx context.Context, _ ReadTextFileRequest) (ReadTextFileResponse, error) { return ReadTextFileResponse{}, nil } -func (geminiClient) WriteTextFile(WriteTextFileRequest) error { return nil } +func (geminiClient) WriteTextFile(ctx context.Context, _ WriteTextFileRequest) error { return nil } // Example_gemini connects to a Gemini CLI speaking ACP over stdio, // then initializes, opens a session, and sends a prompt. func Example_gemini() { + ctx := context.Background() cmd := exec.Command("gemini", "--experimental-acp") stdin, _ := cmd.StdinPipe() stdout, _ := cmd.StdoutPipe() _ = cmd.Start() conn := NewClientSideConnection(geminiClient{}, stdin, stdout) - _, _ = conn.Initialize(InitializeRequest{ + _, _ = conn.Initialize(ctx, InitializeRequest{ ProtocolVersion: ProtocolVersionNumber, ClientCapabilities: ClientCapabilities{ Fs: FileSystemCapability{ @@ -50,11 +52,11 @@ func Example_gemini() { Terminal: true, }, }) - sess, _ := conn.NewSession(NewSessionRequest{ + sess, _ := conn.NewSession(ctx, NewSessionRequest{ Cwd: "/", McpServers: []McpServer{}, }) - _, _ = conn.Prompt(PromptRequest{ + _, _ = conn.Prompt(ctx, PromptRequest{ SessionId: sess.SessionId, Prompt: []ContentBlock{TextBlock("list files")}, }) diff --git a/go/types_gen.go b/go/types_gen.go index 549ecc4..b0e2b17 100644 --- a/go/types_gen.go +++ b/go/types_gen.go @@ -3,6 +3,7 @@ package acp import ( + "context" "encoding/json" "fmt" ) @@ -1083,28 +1084,28 @@ func (v *WriteTextFileRequest) Validate() error { } type Agent interface { - Authenticate(params AuthenticateRequest) error - Initialize(params InitializeRequest) (InitializeResponse, error) - Cancel(params CancelNotification) error - NewSession(params NewSessionRequest) (NewSessionResponse, error) - Prompt(params PromptRequest) (PromptResponse, error) + Authenticate(ctx context.Context, params AuthenticateRequest) error + Initialize(ctx context.Context, params InitializeRequest) (InitializeResponse, error) + Cancel(ctx context.Context, params CancelNotification) error + NewSession(ctx context.Context, params NewSessionRequest) (NewSessionResponse, error) + Prompt(ctx context.Context, params PromptRequest) (PromptResponse, error) } // AgentLoader defines optional support for loading sessions. Implement and advertise the capability to enable 'session/load'. type AgentLoader interface { - LoadSession(params LoadSessionRequest) error + LoadSession(ctx context.Context, params LoadSessionRequest) error } type Client interface { - ReadTextFile(params ReadTextFileRequest) (ReadTextFileResponse, error) - WriteTextFile(params WriteTextFileRequest) error - RequestPermission(params RequestPermissionRequest) (RequestPermissionResponse, error) - SessionUpdate(params SessionNotification) error + ReadTextFile(ctx context.Context, params ReadTextFileRequest) (ReadTextFileResponse, error) + WriteTextFile(ctx context.Context, params WriteTextFileRequest) error + RequestPermission(ctx context.Context, params RequestPermissionRequest) (RequestPermissionResponse, error) + SessionUpdate(ctx context.Context, params SessionNotification) error } // ClientTerminal defines terminal-related experimental methods (x-docs-ignore). Implement and advertise 'terminal: true' to enable 'terminal/*'. type ClientTerminal interface { - CreateTerminal(params CreateTerminalRequest) (CreateTerminalResponse, error) - TerminalOutput(params TerminalOutputRequest) (TerminalOutputResponse, error) - ReleaseTerminal(params ReleaseTerminalRequest) error - WaitForTerminalExit(params WaitForTerminalExitRequest) (WaitForTerminalExitResponse, error) + CreateTerminal(ctx context.Context, params CreateTerminalRequest) (CreateTerminalResponse, error) + TerminalOutput(ctx context.Context, params TerminalOutputRequest) (TerminalOutputResponse, error) + ReleaseTerminal(ctx context.Context, params ReleaseTerminalRequest) error + WaitForTerminalExit(ctx context.Context, params WaitForTerminalExitRequest) (WaitForTerminalExitResponse, error) } From 977cdb3ccc2f168571b64397bd3f62a1af0c69cf Mon Sep 17 00:00:00 2001 From: Thomas Kosiewski Date: Tue, 2 Sep 2025 09:14:29 +0200 Subject: [PATCH 11/22] refactor: convert fields to pointer types and update union generation Change-Id: Iaf75a10da7807d196dde16ec744c72182d767dd2 Signed-off-by: Thomas Kosiewski --- go/acp_test.go | 27 +- go/cmd/generate/internal/emit/helpers.go | 14 +- go/cmd/generate/internal/emit/types.go | 242 +++++++++--- go/constants_gen.go | 1 + go/example/agent/main.go | 3 +- go/example/claude-code/main.go | 22 +- go/example/client/main.go | 18 +- go/example/gemini/main.go | 20 +- go/example_agent_test.go | 2 +- go/example_client_test.go | 15 +- go/helpers_gen.go | 29 +- go/types_gen.go | 450 +++++++++++++++++------ 12 files changed, 614 insertions(+), 229 deletions(-) diff --git a/go/acp_test.go b/go/acp_test.go index 987d56c..06cd606 100644 --- a/go/acp_test.go +++ b/go/acp_test.go @@ -237,7 +237,11 @@ func TestConnectionHandlesMessageOrdering(t *testing.T) { return ReadTextFileResponse{Content: "test content"}, nil }, RequestPermissionFunc: func(_ context.Context, req RequestPermissionRequest) (RequestPermissionResponse, error) { - push("requestPermission called: " + req.ToolCall.Title) + title := "" + if req.ToolCall.Title != nil { + title = *req.ToolCall.Title + } + push("requestPermission called: " + title) return RequestPermissionResponse{Outcome: RequestPermissionOutcome{Selected: &RequestPermissionOutcomeSelected{OptionId: "allow"}}}, nil }, SessionUpdateFunc: func(context.Context, SessionNotification) error { return nil }, @@ -280,14 +284,11 @@ func TestConnectionHandlesMessageOrdering(t *testing.T) { if _, err := as.RequestPermission(context.Background(), RequestPermissionRequest{ SessionId: "test-session", ToolCall: ToolCallUpdate{ - Title: "Execute command", + Title: Ptr("Execute command"), Kind: ptr(ToolKindExecute), Status: ptr(ToolCallStatusPending), ToolCallId: "tool-123", - Content: []ToolCallContent{{ - Type: "content", - Content: &ContentBlock{Type: "text", Text: &TextContent{Text: "ls -la"}}, - }}, + Content: []ToolCallContent{ToolContent(TextBlock("ls -la"))}, }, Options: []PermissionOption{ {Kind: "allow_once", Name: "Allow", OptionId: "allow"}, @@ -403,7 +404,19 @@ func TestConnectionHandlesInitialize(t *testing.T) { }, c2aW, a2cR) _ = NewAgentSideConnection(agentFuncs{ InitializeFunc: func(_ context.Context, p InitializeRequest) (InitializeResponse, error) { - return InitializeResponse{ProtocolVersion: p.ProtocolVersion, AgentCapabilities: AgentCapabilities{LoadSession: true}, AuthMethods: []AuthMethod{{Id: "oauth", Name: "OAuth", Description: "Authenticate with OAuth"}}}, nil + return InitializeResponse{ + ProtocolVersion: p.ProtocolVersion, + AgentCapabilities: AgentCapabilities{ + LoadSession: true, + }, + AuthMethods: []AuthMethod{ + { + Id: "oauth", + Name: "OAuth", + Description: Ptr("Authenticate with OAuth"), + }, + }, + }, nil }, NewSessionFunc: func(context.Context, NewSessionRequest) (NewSessionResponse, error) { return NewSessionResponse{SessionId: "test-session"}, nil diff --git a/go/cmd/generate/internal/emit/helpers.go b/go/cmd/generate/internal/emit/helpers.go index d1a3429..a9c1da0 100644 --- a/go/cmd/generate/internal/emit/helpers.go +++ b/go/cmd/generate/internal/emit/helpers.go @@ -64,10 +64,8 @@ func WriteHelpersJen(outDir string, _ *load.Schema, _ *load.Meta) error { // ToolCall content helpers f.Comment("ToolContent wraps a content block as tool-call content.") f.Func().Id("ToolContent").Params(Id("block").Id("ContentBlock")).Id("ToolCallContent").Block( - Var().Id("b").Id("ContentBlock").Op("=").Id("block"), Return(Id("ToolCallContent").Values(Dict{ - Id("Type"): Lit("content"), - Id("Content"): Op("&").Id("b"), + Id("Content"): Op("&").Id("ToolCallContentContent").Values(Dict{Id("Content"): Id("block"), Id("Type"): Lit("content")}), })), ) f.Line() @@ -79,12 +77,7 @@ func WriteHelpersJen(outDir string, _ *load.Schema, _ *load.Meta) error { Id("o").Op("=").Op("&").Id("oldText").Index(Lit(0)), ), Return(Id("ToolCallContent").Values(Dict{ - Id("Type"): Lit("diff"), - Id("Diff"): Op("&").Id("DiffContent").Values(Dict{ - Id("Path"): Id("path"), - Id("NewText"): Id("newText"), - Id("OldText"): Id("o"), - }), + Id("Diff"): Op("&").Id("ToolCallContentDiff").Values(Dict{Id("Path"): Id("path"), Id("NewText"): Id("newText"), Id("OldText"): Id("o"), Id("Type"): Lit("diff")}), })), ) f.Line() @@ -92,8 +85,7 @@ func WriteHelpersJen(outDir string, _ *load.Schema, _ *load.Meta) error { f.Comment("ToolTerminalRef constructs a terminal reference tool-call content.") f.Func().Id("ToolTerminalRef").Params(Id("terminalId").String()).Id("ToolCallContent").Block( Return(Id("ToolCallContent").Values(Dict{ - Id("Type"): Lit("terminal"), - Id("Terminal"): Op("&").Id("TerminalRef").Values(Dict{Id("TerminalId"): Id("terminalId")}), + Id("Terminal"): Op("&").Id("ToolCallContentTerminal").Values(Dict{Id("TerminalId"): Id("terminalId"), Id("Type"): Lit("terminal")}), })), ) f.Line() diff --git a/go/cmd/generate/internal/emit/types.go b/go/cmd/generate/internal/emit/types.go index fce7502..b56af4e 100644 --- a/go/cmd/generate/internal/emit/types.go +++ b/go/cmd/generate/internal/emit/types.go @@ -5,6 +5,7 @@ import ( "fmt" "os" "path/filepath" + "slices" "sort" "strings" @@ -62,14 +63,18 @@ func WriteTypesJen(outDir string, schema *load.Schema, meta *load.Meta) error { f.Line() case name == "ContentBlock": emitContentBlockJen(f) - case name == "ToolCallContent": - emitToolCallContentJen(f) - case name == "EmbeddedResourceResource": - emitEmbeddedResourceResourceJen(f) - case name == "RequestPermissionOutcome": - emitRequestPermissionOutcomeJen(f) case name == "SessionUpdate": emitSessionUpdateJen(f) + case len(def.AnyOf) > 0 && !def.DocsIgnore: + emitAnyOfUnionJen(f, name, def) + case len(def.OneOf) > 0 && !isStringConstUnion(def) && !def.DocsIgnore: + // Generic union generation for non-enum oneOf + // Reuse same path as anyOf by treating oneOf as anyOf here + // Temporarily map OneOf into def.AnyOf for emission + tmp := *def + tmp.AnyOf = def.OneOf + tmp.OneOf = nil + emitAnyOfUnionJen(f, name, &tmp) case ir.PrimaryType(def) == "object" && len(def.Properties) > 0: st := []Code{} req := map[string]struct{}{} @@ -105,7 +110,7 @@ func WriteTypesJen(outDir string, schema *load.Schema, meta *load.Meta) error { } // validators for selected types - if strings.HasSuffix(name, "Request") || strings.HasSuffix(name, "Response") || strings.HasSuffix(name, "Notification") || name == "ContentBlock" || name == "ToolCallContent" || name == "SessionUpdate" || name == "ToolCallUpdate" { + if strings.HasSuffix(name, "Request") || strings.HasSuffix(name, "Response") || strings.HasSuffix(name, "Notification") || name == "ContentBlock" || name == "SessionUpdate" || name == "ToolCallUpdate" { emitValidateJen(f, name, def) } } @@ -261,7 +266,7 @@ func emitValidateJen(f *File, name string, def *load.Definition) { return } if def != nil && ir.PrimaryType(def) == "object" { - if !(strings.HasSuffix(name, "Request") || strings.HasSuffix(name, "Response") || strings.HasSuffix(name, "Notification")) { + if !strings.HasSuffix(name, "Request") && !strings.HasSuffix(name, "Response") && !strings.HasSuffix(name, "Notification") { return } f.Func().Params(Id("v").Op("*").Id(name)).Id("Validate").Params().Params(Error()).BlockFunc(func(g *Group) { @@ -272,13 +277,7 @@ func emitValidateJen(f *File, name string, def *load.Definition) { sort.Strings(pkeys) for _, propName := range pkeys { pDef := def.Properties[propName] - required := false - for _, r := range def.Required { - if r == propName { - required = true - break - } - } + required := slices.Contains(def.Required, propName) field := util.ToExportedField(propName) if required { switch ir.PrimaryType(pDef) { @@ -356,6 +355,28 @@ func jenTypeForOptional(d *load.Definition) Code { if len(list) == 0 { list = d.OneOf } + // Case: property type is a union like ["string","null"] + if arr, ok := d.Type.([]any); ok && len(arr) == 2 { + var other string + for _, v := range arr { + if s, ok2 := v.(string); ok2 { + if s == "null" { + continue + } + other = s + } + } + switch other { + case "string": + return Op("*").String() + case "integer": + return Op("*").Int() + case "number": + return Op("*").Float64() + case "boolean": + return Op("*").Bool() + } + } if len(list) == 2 { var nonNull *load.Definition for _, e := range list { @@ -572,49 +593,156 @@ func emitEmbeddedResourceResourceJen(f *File) { f.Line() } -func emitRequestPermissionOutcomeJen(f *File) { - f.Type().Id("RequestPermissionOutcomeCancelled").Struct() - f.Type().Id("RequestPermissionOutcomeSelected").Struct( - Id("OptionId").Id("PermissionOptionId").Tag(map[string]string{"json": "optionId"}), - ) +// emitAvailableCommandInputJen generates a concrete variant type for anyOf and a thin union wrapper +// that supports JSON unmarshal by probing object shape. Currently the schema defines one variant +// (title: UnstructuredCommandInput) with a required 'hint' field. +func emitAnyOfUnionJen(f *File, name string, def *load.Definition) { + // Collect variant names and generate inline structs for object variants if needed + type variantInfo struct { + fieldName string + typeName string + required []string + isObject bool + consts map[string]any + } + variants := []variantInfo{} + for idx, v := range def.AnyOf { + if v == nil { + continue + } + tname := v.Title + if tname == "" { + if v.Ref != "" { + // derive from ref path + if strings.HasPrefix(v.Ref, "#/$defs/") { + tname = v.Ref[len("#/$defs/"):] + } + } else { + // Derive from const outcome/type if present + if out, ok := v.Properties["outcome"]; ok && out != nil && out.Const != nil { + s := fmt.Sprint(out.Const) + tname = name + util.ToExportedField(s) + } else if typ, ok2 := v.Properties["type"]; ok2 && typ != nil && typ.Const != nil { + s := fmt.Sprint(typ.Const) + tname = name + util.ToExportedField(s) + } else { + tname = name + fmt.Sprintf("Variant%d", idx+1) + } + } + } + fieldName := tname + if out, ok := v.Properties["outcome"]; ok && out != nil && out.Const != nil { + s := fmt.Sprint(out.Const) + fieldName = util.ToExportedField(s) + } else if typ, ok2 := v.Properties["type"]; ok2 && typ != nil && typ.Const != nil { + s := fmt.Sprint(typ.Const) + fieldName = util.ToExportedField(s) + } + // If this variant is an inline object, generate its struct + isObj := len(v.Properties) > 0 + if isObj && v.Ref == "" { + st := []Code{} + req := map[string]struct{}{} + for _, r := range v.Required { + req[r] = struct{}{} + } + pkeys := make([]string, 0, len(v.Properties)) + for pk := range v.Properties { + pkeys = append(pkeys, pk) + } + sort.Strings(pkeys) + // Variant doc comment + if v.Description != "" { + f.Comment(util.SanitizeComment(v.Description)) + } + for _, pk := range pkeys { + pDef := v.Properties[pk] + field := util.ToExportedField(pk) + if pDef.Description != "" { + st = append(st, Comment(util.SanitizeComment(pDef.Description))) + } + tag := pk + if _, ok := req[pk]; !ok { + tag = pk + ",omitempty" + } + st = append(st, Id(field).Add(jenTypeForOptional(pDef)).Tag(map[string]string{"json": tag})) + } + f.Type().Id(tname).Struct(st...) + f.Line() + } + // Collect const properties for detection + consts := map[string]any{} + for pk, pd := range v.Properties { + if pd != nil && pd.Const != nil { + consts[pk] = pd.Const + } + } + variants = append(variants, variantInfo{fieldName: fieldName, typeName: tname, required: v.Required, isObject: isObj, consts: consts}) + } + // Union wrapper + st := []Code{} + for _, vi := range variants { + st = append(st, Id(vi.fieldName).Op("*").Id(vi.typeName).Tag(map[string]string{"json": "-"})) + } + f.Type().Id(name).Struct(st...) f.Line() - f.Type().Id("RequestPermissionOutcome").Struct( - Id("Cancelled").Op("*").Id("RequestPermissionOutcomeCancelled").Tag(map[string]string{"json": "-"}), - Id("Selected").Op("*").Id("RequestPermissionOutcomeSelected").Tag(map[string]string{"json": "-"}), - ) - f.Func().Params(Id("o").Op("*").Id("RequestPermissionOutcome")).Id("UnmarshalJSON").Params(Id("b").Index().Byte()).Error().Block( - Var().Id("m").Map(String()).Qual("encoding/json", "RawMessage"), - If(List(Id("err")).Op(":=").Qual("encoding/json", "Unmarshal").Call(Id("b"), Op("&").Id("m")), Id("err").Op("!=").Nil()).Block(Return(Id("err"))), - Var().Id("outcome").String(), - If(List(Id("v"), Id("ok")).Op(":=").Id("m").Index(Lit("outcome")), Id("ok")).Block( - Qual("encoding/json", "Unmarshal").Call(Id("v"), Op("&").Id("outcome")), - ), - Switch(Id("outcome")).Block( - Case(Lit("cancelled")).Block( - Id("o").Dot("Cancelled").Op("=").Op("&").Id("RequestPermissionOutcomeCancelled").Values(), - Return(Nil()), - ), - Case(Lit("selected")).Block( - Var().Id("v2").Id("RequestPermissionOutcomeSelected"), - If(List(Id("err")).Op(":=").Qual("encoding/json", "Unmarshal").Call(Id("b"), Op("&").Id("v2")), Id("err").Op("!=").Nil()).Block(Return(Id("err"))), - Id("o").Dot("Selected").Op("=").Op("&").Id("v2"), - Return(Nil()), - ), - ), - Return(Nil()), - ) - f.Func().Params(Id("o").Id("RequestPermissionOutcome")).Id("MarshalJSON").Params().Params(Index().Byte(), Error()).Block( - If(Id("o").Dot("Cancelled").Op("!=").Nil()).Block( - Return(Qual("encoding/json", "Marshal").Call(Map(String()).Any().Values(Dict{Lit("outcome"): Lit("cancelled")}))), - ), - If(Id("o").Dot("Selected").Op("!=").Nil()).Block( - Return(Qual("encoding/json", "Marshal").Call(Map(String()).Any().Values(Dict{ - Lit("optionId"): Id("o").Dot("Selected").Dot("OptionId"), - Lit("outcome"): Lit("selected"), - }))), - ), - Return(Index().Byte().Values(), Nil()), - ) + // Unmarshal: prefer required-field presence checks for object variants, then fallback to try-unmarshal for all + f.Func().Params(Id("u").Op("*").Id(name)).Id("UnmarshalJSON").Params(Id("b").Index().Byte()).Error().BlockFunc(func(g *Group) { + g.Var().Id("m").Map(String()).Qual("encoding/json", "RawMessage") + g.If(List(Id("err")).Op(":=").Qual("encoding/json", "Unmarshal").Call(Id("b"), Op("&").Id("m")), Id("err").Op("!=").Nil()).Block(Return(Id("err"))) + // Try required-field detection for object variants + for _, vi := range variants { + if vi.isObject && len(vi.required) > 0 { + stmts := []Code{ + Var().Id("v").Id(vi.typeName), + Var().Id("match").Bool().Op("=").Lit(true), + } + for _, rk := range vi.required { + stmts = append(stmts, If(List(Id("_"), Id("ok")).Op(":=").Id("m").Index(Lit(rk)), Op("!").Id("ok")).Block(Id("match").Op("=").Lit(false))) + } + // Check const-valued fields + for ck, cv := range vi.consts { + // read m[ck] and compare to const value (stringify for simplicity) + stmts = append(stmts, + Var().Id("raw").Qual("encoding/json", "RawMessage"), Var().Id("ok").Bool(), + List(Id("raw"), Id("ok")).Op("=").Id("m").Index(Lit(ck)), + If(Op("!").Id("ok")).Block(Id("match").Op("=").Lit(false)), + If(Id("ok")).Block( + Var().Id("tmp").Any(), + If(List(Id("err")).Op(":=").Qual("encoding/json", "Unmarshal").Call(Id("raw"), Op("&").Id("tmp")), Id("err").Op("!=").Nil()).Block(Return(Id("err"))), + If(Qual("fmt", "Sprint").Call(Id("tmp")).Op("!=").Qual("fmt", "Sprint").Call(Lit(cv))).Block(Id("match").Op("=").Lit(false)), + ), + ) + } + stmts = append(stmts, If(Id("match")).Block( + If(List(Id("err")).Op(":=").Qual("encoding/json", "Unmarshal").Call(Id("b"), Op("&").Id("v")), Id("err").Op("!=").Nil()).Block(Return(Id("err"))), + Id("u").Dot(vi.fieldName).Op("=").Op("&").Id("v"), + Return(Nil()), + )) + g.Block(stmts...) + } + } + // Fallback: try to unmarshal into each variant sequentially + for _, vi := range variants { + g.Block( + Var().Id("v").Id(vi.typeName), + If(List(Id("err")).Op(":=").Qual("encoding/json", "Unmarshal").Call(Id("b"), Op("&").Id("v")), Id("err").Op("==").Nil()).Block( + Id("u").Dot(vi.fieldName).Op("=").Op("&").Id("v"), + Return(Nil()), + ), + ) + } + g.Return(Nil()) + }) + // Marshal: pick first non-nil + f.Func().Params(Id("u").Id(name)).Id("MarshalJSON").Params().Params(Index().Byte(), Error()).BlockFunc(func(g *Group) { + for _, vi := range variants { + g.If(Id("u").Dot(vi.fieldName).Op("!=").Nil()).Block( + Return(Qual("encoding/json", "Marshal").Call(Op("*").Id("u").Dot(vi.fieldName))), + ) + } + g.Return(Index().Byte().Values(), Nil()) + }) f.Line() } diff --git a/go/constants_gen.go b/go/constants_gen.go index 9e96b6e..1e08f03 100644 --- a/go/constants_gen.go +++ b/go/constants_gen.go @@ -22,6 +22,7 @@ const ( ClientMethodSessionRequestPermission = "session/request_permission" ClientMethodSessionUpdate = "session/update" ClientMethodTerminalCreate = "terminal/create" + ClientMethodTerminalKill = "terminal/kill" ClientMethodTerminalOutput = "terminal/output" ClientMethodTerminalRelease = "terminal/release" ClientMethodTerminalWaitForExit = "terminal/wait_for_exit" diff --git a/go/example/agent/main.go b/go/example/agent/main.go index 6942f4a..9084879 100644 --- a/go/example/agent/main.go +++ b/go/example/agent/main.go @@ -188,7 +188,7 @@ func (a *exampleAgent) simulateTurn(ctx context.Context, sid string) error { SessionId: acp.SessionId(sid), ToolCall: acp.ToolCallUpdate{ ToolCallId: acp.ToolCallId("call_2"), - Title: "Modifying critical configuration file", + Title: acp.Ptr("Modifying critical configuration file"), Kind: acp.Ptr(acp.ToolKindEdit), Status: acp.Ptr(acp.ToolCallStatusPending), Locations: []acp.ToolCallLocation{{Path: "/home/user/project/config.json"}}, @@ -218,6 +218,7 @@ func (a *exampleAgent) simulateTurn(ctx context.Context, sid string) error { ToolCallId: acp.ToolCallId("call_2"), Status: acp.Ptr(acp.ToolCallStatusCompleted), RawOutput: map[string]any{"success": true, "message": "Configuration updated"}, + Title: acp.Ptr("Modifying critical configuration file"), }}, }); err != nil { return err diff --git a/go/example/claude-code/main.go b/go/example/claude-code/main.go index c7edb41..296028e 100644 --- a/go/example/claude-code/main.go +++ b/go/example/claude-code/main.go @@ -37,7 +37,11 @@ func (c *replClient) RequestPermission(ctx context.Context, params acp.RequestPe return acp.RequestPermissionResponse{Outcome: acp.RequestPermissionOutcome{Cancelled: &acp.RequestPermissionOutcomeCancelled{}}}, nil } - fmt.Printf("\n🔐 Permission requested: %s\n", params.ToolCall.Title) + title := "" + if params.ToolCall.Title != nil { + title = *params.ToolCall.Title + } + fmt.Printf("\n🔐 Permission requested: %s\n", title) fmt.Println("\nOptions:") for i, opt := range params.Options { fmt.Printf(" %d. %s (%s)\n", i+1, opt.Name, opt.Kind) @@ -115,16 +119,16 @@ func (c *replClient) ReadTextFile(ctx context.Context, params acp.ReadTextFileRe return acp.ReadTextFileResponse{}, fmt.Errorf("read %s: %w", params.Path, err) } content := string(b) - if params.Line > 0 || params.Limit > 0 { + if params.Line != nil || params.Limit != nil { lines := strings.Split(content, "\n") start := 0 - if params.Line > 0 { - start = min(max(params.Line-1, 0), len(lines)) + if params.Line != nil && *params.Line > 0 { + start = min(max(*params.Line-1, 0), len(lines)) } end := len(lines) - if params.Limit > 0 { - if start+params.Limit < end { - end = start + params.Limit + if params.Limit != nil && *params.Limit > 0 { + if start+*params.Limit < end { + end = start + *params.Limit } } content = strings.Join(lines[start:end], "\n") @@ -161,7 +165,7 @@ func main() { // Invoke Claude Code via npx ctx, cancel := context.WithCancel(context.Background()) defer cancel() - cmd := exec.CommandContext(ctx, "npx", "-y", "@zed-industries/claude-code-acp") + cmd := exec.CommandContext(ctx, "npx", "-y", "@zed-industries/claude-code-acp@latest") cmd.Stderr = os.Stderr stdin, err := cmd.StdinPipe() if err != nil { @@ -235,7 +239,7 @@ func main() { cancel() return case ":cancel": - _ = conn.Cancel(ctx, acp.CancelNotification(newSess)) + _ = conn.Cancel(ctx, acp.CancelNotification{SessionId: newSess.SessionId}) continue } // Send prompt and wait for completion while streaming updates are printed via SessionUpdate diff --git a/go/example/client/main.go b/go/example/client/main.go index 670821e..6c6880d 100644 --- a/go/example/client/main.go +++ b/go/example/client/main.go @@ -21,7 +21,11 @@ var ( ) func (e *exampleClient) RequestPermission(ctx context.Context, params acp.RequestPermissionRequest) (acp.RequestPermissionResponse, error) { - fmt.Printf("\n🔐 Permission requested: %s\n", params.ToolCall.Title) + title := "" + if params.ToolCall.Title != nil { + title = *params.ToolCall.Title + } + fmt.Printf("\n🔐 Permission requested: %s\n", title) fmt.Println("\nOptions:") for i, opt := range params.Options { fmt.Printf(" %d. %s (%s)\n", i+1, opt.Name, opt.Kind) @@ -111,16 +115,16 @@ func (e *exampleClient) ReadTextFile(ctx context.Context, params acp.ReadTextFil } content := string(b) // Apply optional line/limit (1-based line index) - if params.Line > 0 || params.Limit > 0 { + if params.Line != nil || params.Limit != nil { lines := strings.Split(content, "\n") start := 0 - if params.Line > 0 { - start = min(max(params.Line-1, 0), len(lines)) + if params.Line != nil && *params.Line > 0 { + start = min(max(*params.Line-1, 0), len(lines)) } end := len(lines) - if params.Limit > 0 { - if start+params.Limit < end { - end = start + params.Limit + if params.Limit != nil && *params.Limit > 0 { + if start+*params.Limit < end { + end = start + *params.Limit } } content = strings.Join(lines[start:end], "\n") diff --git a/go/example/gemini/main.go b/go/example/gemini/main.go index cf317d7..70c34f5 100644 --- a/go/example/gemini/main.go +++ b/go/example/gemini/main.go @@ -40,7 +40,11 @@ func (c *replClient) RequestPermission(ctx context.Context, params acp.RequestPe return acp.RequestPermissionResponse{Outcome: acp.RequestPermissionOutcome{Cancelled: &acp.RequestPermissionOutcomeCancelled{}}}, nil } - fmt.Printf("\n🔐 Permission requested: %s\n", params.ToolCall.Title) + title := "" + if params.ToolCall.Title != nil { + title = *params.ToolCall.Title + } + fmt.Printf("\n🔐 Permission requested: %s\n", title) fmt.Println("\nOptions:") for i, opt := range params.Options { fmt.Printf(" %d. %s (%s)\n", i+1, opt.Name, opt.Kind) @@ -118,16 +122,16 @@ func (c *replClient) ReadTextFile(ctx context.Context, params acp.ReadTextFileRe return acp.ReadTextFileResponse{}, fmt.Errorf("read %s: %w", params.Path, err) } content := string(b) - if params.Line > 0 || params.Limit > 0 { + if params.Line != nil || params.Limit != nil { lines := strings.Split(content, "\n") start := 0 - if params.Line > 0 { - start = min(max(params.Line-1, 0), len(lines)) + if params.Line != nil && *params.Line > 0 { + start = min(max(*params.Line-1, 0), len(lines)) } end := len(lines) - if params.Limit > 0 { - if start+params.Limit < end { - end = start + params.Limit + if params.Limit != nil && *params.Limit > 0 { + if start+*params.Limit < end { + end = start + *params.Limit } } content = strings.Join(lines[start:end], "\n") @@ -259,7 +263,7 @@ func main() { cancel() return case ":cancel": - _ = conn.Cancel(ctx, acp.CancelNotification(newSess)) + _ = conn.Cancel(ctx, acp.CancelNotification{SessionId: newSess.SessionId}) continue } // Send prompt and wait for completion while streaming updates are printed via SessionUpdate diff --git a/go/example_agent_test.go b/go/example_agent_test.go index 8c5fa1b..0e68180 100644 --- a/go/example_agent_test.go +++ b/go/example_agent_test.go @@ -53,7 +53,7 @@ func (a *agentExample) Prompt(ctx context.Context, p PromptRequest) (PromptRespo SessionId: p.SessionId, ToolCall: ToolCallUpdate{ ToolCallId: ToolCallId("call_1"), - Title: "Modifying configuration", + Title: Ptr("Modifying configuration"), Kind: Ptr(ToolKindEdit), Status: Ptr(ToolCallStatusPending), Locations: []ToolCallLocation{{Path: "/project/config.json"}}, diff --git a/go/example_client_test.go b/go/example_client_test.go index c4fad1c..6d3f5ba 100644 --- a/go/example_client_test.go +++ b/go/example_client_test.go @@ -40,7 +40,8 @@ func (clientExample) SessionUpdate(ctx context.Context, n SessionNotification) e fmt.Println("[", c.Type, "]") } case u.ToolCall != nil: - fmt.Printf("\n[tool] %s (%s)\n", u.ToolCall.Title, u.ToolCall.Status) + title := u.ToolCall.Title + fmt.Printf("\n[tool] %s (%s)\n", title, u.ToolCall.Status) case u.ToolCallUpdate != nil: fmt.Printf("\n[tool] %s -> %v\n", u.ToolCallUpdate.ToolCallId, u.ToolCallUpdate.Status) } @@ -66,20 +67,20 @@ func (clientExample) ReadTextFile(ctx context.Context, p ReadTextFileRequest) (R return ReadTextFileResponse{}, err } content := string(b) - if p.Line > 0 || p.Limit > 0 { + if p.Line != nil || p.Limit != nil { lines := strings.Split(content, "\n") start := 0 - if p.Line > 0 { - if p.Line-1 > 0 { - start = p.Line - 1 + if p.Line != nil && *p.Line > 0 { + if *p.Line-1 > 0 { + start = *p.Line - 1 } if start > len(lines) { start = len(lines) } } end := len(lines) - if p.Limit > 0 && start+p.Limit < end { - end = start + p.Limit + if p.Limit != nil && *p.Limit > 0 && start+*p.Limit < end { + end = start + *p.Limit } content = strings.Join(lines[start:end], "\n") } diff --git a/go/helpers_gen.go b/go/helpers_gen.go index 51d45f2..9a493a9 100644 --- a/go/helpers_gen.go +++ b/go/helpers_gen.go @@ -54,11 +54,10 @@ func ResourceBlock(res EmbeddedResource) ContentBlock { // ToolContent wraps a content block as tool-call content. func ToolContent(block ContentBlock) ToolCallContent { - var b ContentBlock = block - return ToolCallContent{ - Content: &b, + return ToolCallContent{Content: &ToolCallContentContent{ + Content: block, Type: "content", - } + }} } // ToolDiffContent constructs a diff tool-call content. If oldText is omitted, the field is left empty. @@ -67,22 +66,20 @@ func ToolDiffContent(path string, newText string, oldText ...string) ToolCallCon if len(oldText) > 0 { o = &oldText[0] } - return ToolCallContent{ - Diff: &DiffContent{ - NewText: newText, - OldText: o, - Path: path, - }, - Type: "diff", - } + return ToolCallContent{Diff: &ToolCallContentDiff{ + NewText: newText, + OldText: o, + Path: path, + Type: "diff", + }} } // ToolTerminalRef constructs a terminal reference tool-call content. func ToolTerminalRef(terminalId string) ToolCallContent { - return ToolCallContent{ - Terminal: &TerminalRef{TerminalId: terminalId}, - Type: "terminal", - } + return ToolCallContent{Terminal: &ToolCallContentTerminal{ + TerminalId: terminalId, + Type: "terminal", + }} } // Ptr returns a pointer to v. diff --git a/go/types_gen.go b/go/types_gen.go index b0e2b17..0579924 100644 --- a/go/types_gen.go +++ b/go/types_gen.go @@ -30,9 +30,9 @@ type AgentResponse any // Optional annotations for the client. The client can use annotations to inform how objects are used or displayed type Annotations struct { - Audience []Role `json:"audience,omitempty"` - LastModified string `json:"lastModified,omitempty"` - Priority float64 `json:"priority,omitempty"` + Audience []Role `json:"audience,omitempty"` + LastModified *string `json:"lastModified,omitempty"` + Priority *float64 `json:"priority,omitempty"` } // Audio provided to or from an LLM. @@ -45,7 +45,7 @@ type AudioContent struct { // Describes an available authentication method. type AuthMethod struct { // Optional description providing more details about this authentication method. - Description string `json:"description,omitempty"` + Description *string `json:"description,omitempty"` // Unique identifier for this authentication method. Id AuthMethodId `json:"id"` // Human-readable name of the authentication method. @@ -65,11 +65,66 @@ func (v *AuthenticateRequest) Validate() error { return nil } +// Information about a command. +type AvailableCommand struct { + // Human-readable description of what the command does. + Description string `json:"description"` + // Input for the command if required + Input *AvailableCommandInput `json:"input,omitempty"` + // Command name (e.g., "create_plan", "research_codebase"). + Name string `json:"name"` +} + +// All text that was typed after the command name is provided as input. +type UnstructuredCommandInput struct { + // A brief description of the expected input + Hint string `json:"hint"` +} + +type AvailableCommandInput struct { + UnstructuredCommandInput *UnstructuredCommandInput `json:"-"` +} + +func (u *AvailableCommandInput) UnmarshalJSON(b []byte) error { + var m map[string]json.RawMessage + if err := json.Unmarshal(b, &m); err != nil { + return err + } + { + var v UnstructuredCommandInput + var match bool = true + if _, ok := m["hint"]; !ok { + match = false + } + if match { + if err := json.Unmarshal(b, &v); err != nil { + return err + } + u.UnstructuredCommandInput = &v + return nil + } + } + { + var v UnstructuredCommandInput + if err := json.Unmarshal(b, &v); err == nil { + u.UnstructuredCommandInput = &v + return nil + } + } + return nil +} +func (u AvailableCommandInput) MarshalJSON() ([]byte, error) { + if u.UnstructuredCommandInput != nil { + return json.Marshal(*u.UnstructuredCommandInput) + } + return []byte{}, nil +} + // Binary resource contents. type BlobResourceContents struct { - Blob string `json:"blob"` - MimeType string `json:"mimeType,omitempty"` - Uri string `json:"uri"` + Blob string `json:"blob"` + MimeType *string `json:"mimeType,omitempty"` + Uri string `json:"uri"` } // Notification to cancel ongoing operations for a session. See protocol docs: [Cancellation](https://agentclientprotocol.com/protocol/prompt-turn#cancellation) @@ -242,9 +297,9 @@ func (c *ContentBlock) Validate() error { type CreateTerminalRequest struct { Args []string `json:"args,omitempty"` Command string `json:"command"` - Cwd string `json:"cwd,omitempty"` + Cwd *string `json:"cwd,omitempty"` Env []EnvVariable `json:"env,omitempty"` - OutputByteLimit int `json:"outputByteLimit,omitempty"` + OutputByteLimit *int `json:"outputByteLimit,omitempty"` SessionId SessionId `json:"sessionId"` } @@ -278,29 +333,36 @@ type EmbeddedResourceResource struct { BlobResourceContents *BlobResourceContents `json:"-"` } -func (e *EmbeddedResourceResource) UnmarshalJSON(b []byte) error { +func (u *EmbeddedResourceResource) UnmarshalJSON(b []byte) error { var m map[string]json.RawMessage if err := json.Unmarshal(b, &m); err != nil { return err } - if _, ok := m["text"]; ok { + { var v TextResourceContents - if err := json.Unmarshal(b, &v); err != nil { - return err + if err := json.Unmarshal(b, &v); err == nil { + u.TextResourceContents = &v + return nil } - e.TextResourceContents = &v - return nil } - if _, ok2 := m["blob"]; ok2 { + { var v BlobResourceContents - if err := json.Unmarshal(b, &v); err != nil { - return err + if err := json.Unmarshal(b, &v); err == nil { + u.BlobResourceContents = &v + return nil } - e.BlobResourceContents = &v - return nil } return nil } +func (u EmbeddedResourceResource) MarshalJSON() ([]byte, error) { + if u.TextResourceContents != nil { + return json.Marshal(*u.TextResourceContents) + } + if u.BlobResourceContents != nil { + return json.Marshal(*u.BlobResourceContents) + } + return []byte{}, nil +} // An environment variable to set when launching an MCP server. type EnvVariable struct { @@ -323,7 +385,7 @@ type ImageContent struct { Annotations *Annotations `json:"annotations,omitempty"` Data string `json:"data"` MimeType string `json:"mimeType"` - Uri string `json:"uri,omitempty"` + Uri *string `json:"uri,omitempty"` } // Request parameters for the initialize method. Sent by the client to establish connection and negotiate capabilities. See protocol docs: [Initialization](https://agentclientprotocol.com/protocol/initialization) @@ -352,6 +414,18 @@ func (v *InitializeResponse) Validate() error { return nil } +type KillTerminalRequest struct { + SessionId SessionId `json:"sessionId"` + TerminalId string `json:"terminalId"` +} + +func (v *KillTerminalRequest) Validate() error { + if v.TerminalId == "" { + return fmt.Errorf("terminalId is required") + } + return nil +} + // Request parameters for loading an existing session. Only available if the agent supports the 'loadSession' capability. See protocol docs: [Loading Sessions](https://agentclientprotocol.com/protocol/session-setup#loading-sessions) type LoadSessionRequest struct { // The working directory for this session. @@ -404,6 +478,8 @@ func (v *NewSessionRequest) Validate() error { // Response from creating a new session. See protocol docs: [Creating a Session](https://agentclientprotocol.com/protocol/session-setup#creating-a-session) type NewSessionResponse struct { + // **UNSTABLE** Commands that may be executed via 'session/prompt' requests + AvailableCommands []AvailableCommand `json:"availableCommands,omitempty"` // Unique identifier for the created session. Used in all subsequent requests for this conversation. SessionId SessionId `json:"sessionId"` } @@ -510,9 +586,9 @@ type ProtocolVersion int // Request to read content from a text file. Only available if the client supports the 'fs.readTextFile' capability. type ReadTextFileRequest struct { // Optional maximum number of lines to read. - Limit int `json:"limit,omitempty"` + Limit *int `json:"limit,omitempty"` // Optional line number to start reading from (1-based). - Line int `json:"line,omitempty"` + Line *int `json:"line,omitempty"` // Absolute path to the file to read. Path string `json:"path"` // The session ID for this request. @@ -551,9 +627,16 @@ func (v *ReleaseTerminalRequest) Validate() error { } // The outcome of a permission request. -type RequestPermissionOutcomeCancelled struct{} +// The prompt turn was cancelled before the user responded. When a client sends a 'session/cancel' notification to cancel an ongoing prompt turn, it MUST respond to all pending 'session/request_permission' requests with this 'Cancelled' outcome. See protocol docs: [Cancellation](https://agentclientprotocol.com/protocol/prompt-turn#cancellation) +type RequestPermissionOutcomeCancelled struct { + Outcome string `json:"outcome"` +} + +// The user selected one of the provided options. type RequestPermissionOutcomeSelected struct { + // The ID of the option the user selected. OptionId PermissionOptionId `json:"optionId"` + Outcome string `json:"outcome"` } type RequestPermissionOutcome struct { @@ -561,38 +644,94 @@ type RequestPermissionOutcome struct { Selected *RequestPermissionOutcomeSelected `json:"-"` } -func (o *RequestPermissionOutcome) UnmarshalJSON(b []byte) error { +func (u *RequestPermissionOutcome) UnmarshalJSON(b []byte) error { var m map[string]json.RawMessage if err := json.Unmarshal(b, &m); err != nil { return err } - var outcome string - if v, ok := m["outcome"]; ok { - json.Unmarshal(v, &outcome) + { + var v RequestPermissionOutcomeCancelled + var match bool = true + if _, ok := m["outcome"]; !ok { + match = false + } + var raw json.RawMessage + var ok bool + raw, ok = m["outcome"] + if !ok { + match = false + } + if ok { + var tmp any + if err := json.Unmarshal(raw, &tmp); err != nil { + return err + } + if fmt.Sprint(tmp) != fmt.Sprint("cancelled") { + match = false + } + } + if match { + if err := json.Unmarshal(b, &v); err != nil { + return err + } + u.Cancelled = &v + return nil + } } - switch outcome { - case "cancelled": - o.Cancelled = &RequestPermissionOutcomeCancelled{} - return nil - case "selected": - var v2 RequestPermissionOutcomeSelected - if err := json.Unmarshal(b, &v2); err != nil { - return err + { + var v RequestPermissionOutcomeSelected + var match bool = true + if _, ok := m["outcome"]; !ok { + match = false + } + if _, ok := m["optionId"]; !ok { + match = false + } + var raw json.RawMessage + var ok bool + raw, ok = m["outcome"] + if !ok { + match = false + } + if ok { + var tmp any + if err := json.Unmarshal(raw, &tmp); err != nil { + return err + } + if fmt.Sprint(tmp) != fmt.Sprint("selected") { + match = false + } + } + if match { + if err := json.Unmarshal(b, &v); err != nil { + return err + } + u.Selected = &v + return nil + } + } + { + var v RequestPermissionOutcomeCancelled + if err := json.Unmarshal(b, &v); err == nil { + u.Cancelled = &v + return nil + } + } + { + var v RequestPermissionOutcomeSelected + if err := json.Unmarshal(b, &v); err == nil { + u.Selected = &v + return nil } - o.Selected = &v2 - return nil } return nil } -func (o RequestPermissionOutcome) MarshalJSON() ([]byte, error) { - if o.Cancelled != nil { - return json.Marshal(map[string]any{"outcome": "cancelled"}) +func (u RequestPermissionOutcome) MarshalJSON() ([]byte, error) { + if u.Cancelled != nil { + return json.Marshal(*u.Cancelled) } - if o.Selected != nil { - return json.Marshal(map[string]any{ - "optionId": o.Selected.OptionId, - "outcome": "selected", - }) + if u.Selected != nil { + return json.Marshal(*u.Selected) } return []byte{}, nil } @@ -627,11 +766,11 @@ func (v *RequestPermissionResponse) Validate() error { // A resource that the server is capable of reading, included in a prompt or tool call result. type ResourceLink struct { Annotations *Annotations `json:"annotations,omitempty"` - Description string `json:"description,omitempty"` - MimeType string `json:"mimeType,omitempty"` + Description *string `json:"description,omitempty"` + MimeType *string `json:"mimeType,omitempty"` Name string `json:"name"` - Size int `json:"size,omitempty"` - Title string `json:"title,omitempty"` + Size *int `json:"size,omitempty"` + Title *string `json:"title,omitempty"` Uri string `json:"uri"` } @@ -848,8 +987,8 @@ const ( ) type TerminalExitStatus struct { - ExitCode int `json:"exitCode,omitempty"` - Signal string `json:"signal,omitempty"` + ExitCode *int `json:"exitCode,omitempty"` + Signal *string `json:"signal,omitempty"` } type TerminalOutputRequest struct { @@ -885,9 +1024,9 @@ type TextContent struct { // Text-based resource contents. type TextResourceContents struct { - MimeType string `json:"mimeType,omitempty"` - Text string `json:"text"` - Uri string `json:"uri"` + MimeType *string `json:"mimeType,omitempty"` + Text string `json:"text"` + Uri string `json:"uri"` } // Represents a tool call that the language model has requested. Tool calls are actions that the agent executes on behalf of the language model, such as reading files, executing code, or fetching data from external sources. See protocol docs: [Tool Calls](https://agentclientprotocol.com/protocol/tool-calls) @@ -911,73 +1050,174 @@ type ToolCall struct { } // Content produced by a tool call. Tool calls can produce different types of content including standard content blocks (text, images) or file diffs. See protocol docs: [Content](https://agentclientprotocol.com/protocol/tool-calls#content) -type DiffContent struct { - NewText string `json:"newText"` +// Standard content block (text, images, resources). +type ToolCallContentContent struct { + // The actual content block. + Content ContentBlock `json:"content"` + Type string `json:"type"` +} + +// File modification shown as a diff. +type ToolCallContentDiff struct { + // The new content after modification. + NewText string `json:"newText"` + // The original content (None for new files). OldText *string `json:"oldText,omitempty"` - Path string `json:"path"` + // The file path being modified. + Path string `json:"path"` + Type string `json:"type"` } -type TerminalRef struct { + +type ToolCallContentTerminal struct { TerminalId string `json:"terminalId"` + Type string `json:"type"` } type ToolCallContent struct { - Type string `json:"type"` - Content *ContentBlock `json:"-"` - Diff *DiffContent `json:"-"` - Terminal *TerminalRef `json:"-"` + Content *ToolCallContentContent `json:"-"` + Diff *ToolCallContentDiff `json:"-"` + Terminal *ToolCallContentTerminal `json:"-"` } -func (t *ToolCallContent) UnmarshalJSON(b []byte) error { - var probe struct { - Type string `json:"type"` - } - if err := json.Unmarshal(b, &probe); err != nil { +func (u *ToolCallContent) UnmarshalJSON(b []byte) error { + var m map[string]json.RawMessage + if err := json.Unmarshal(b, &m); err != nil { return err } - t.Type = probe.Type - switch probe.Type { - case "content": - var v struct { - Type string `json:"type"` - Content ContentBlock `json:"content"` + { + var v ToolCallContentContent + var match bool = true + if _, ok := m["type"]; !ok { + match = false } - if err := json.Unmarshal(b, &v); err != nil { - return err + if _, ok := m["content"]; !ok { + match = false } - t.Content = &v.Content - case "diff": - var v DiffContent - if err := json.Unmarshal(b, &v); err != nil { - return err + var raw json.RawMessage + var ok bool + raw, ok = m["type"] + if !ok { + match = false } - t.Diff = &v - case "terminal": - var v TerminalRef - if err := json.Unmarshal(b, &v); err != nil { - return err + if ok { + var tmp any + if err := json.Unmarshal(raw, &tmp); err != nil { + return err + } + if fmt.Sprint(tmp) != fmt.Sprint("content") { + match = false + } + } + if match { + if err := json.Unmarshal(b, &v); err != nil { + return err + } + u.Content = &v + return nil } - t.Terminal = &v } - return nil -} - -func (t *ToolCallContent) Validate() error { - switch t.Type { - case "content": - if t.Content == nil { - return fmt.Errorf("toolcallcontent.content missing") + { + var v ToolCallContentDiff + var match bool = true + if _, ok := m["type"]; !ok { + match = false + } + if _, ok := m["path"]; !ok { + match = false + } + if _, ok := m["newText"]; !ok { + match = false + } + var raw json.RawMessage + var ok bool + raw, ok = m["type"] + if !ok { + match = false + } + if ok { + var tmp any + if err := json.Unmarshal(raw, &tmp); err != nil { + return err + } + if fmt.Sprint(tmp) != fmt.Sprint("diff") { + match = false + } + } + if match { + if err := json.Unmarshal(b, &v); err != nil { + return err + } + u.Diff = &v + return nil + } + } + { + var v ToolCallContentTerminal + var match bool = true + if _, ok := m["type"]; !ok { + match = false } - case "diff": - if t.Diff == nil { - return fmt.Errorf("toolcallcontent.diff missing") + if _, ok := m["terminalId"]; !ok { + match = false } - case "terminal": - if t.Terminal == nil { - return fmt.Errorf("toolcallcontent.terminal missing") + var raw json.RawMessage + var ok bool + raw, ok = m["type"] + if !ok { + match = false + } + if ok { + var tmp any + if err := json.Unmarshal(raw, &tmp); err != nil { + return err + } + if fmt.Sprint(tmp) != fmt.Sprint("terminal") { + match = false + } + } + if match { + if err := json.Unmarshal(b, &v); err != nil { + return err + } + u.Terminal = &v + return nil + } + } + { + var v ToolCallContentContent + if err := json.Unmarshal(b, &v); err == nil { + u.Content = &v + return nil + } + } + { + var v ToolCallContentDiff + if err := json.Unmarshal(b, &v); err == nil { + u.Diff = &v + return nil + } + } + { + var v ToolCallContentTerminal + if err := json.Unmarshal(b, &v); err == nil { + u.Terminal = &v + return nil } } return nil } +func (u ToolCallContent) MarshalJSON() ([]byte, error) { + if u.Content != nil { + return json.Marshal(*u.Content) + } + if u.Diff != nil { + return json.Marshal(*u.Diff) + } + if u.Terminal != nil { + return json.Marshal(*u.Terminal) + } + return []byte{}, nil +} // Unique identifier for a tool call within a session. type ToolCallId string @@ -985,7 +1225,7 @@ type ToolCallId string // A file location being accessed or modified by a tool. Enables clients to implement "follow-along" features that track which files the agent is working with in real-time. See protocol docs: [Following the Agent](https://agentclientprotocol.com/protocol/tool-calls#following-the-agent) type ToolCallLocation struct { // Optional line number within the file. - Line int `json:"line,omitempty"` + Line *int `json:"line,omitempty"` // The file path being accessed or modified. Path string `json:"path"` } @@ -1015,7 +1255,7 @@ type ToolCallUpdate struct { // Update the execution status. Status *ToolCallStatus `json:"status,omitempty"` // Update the human-readable title. - Title string `json:"title,omitempty"` + Title *string `json:"title,omitempty"` // The ID of the tool call being updated. ToolCallId ToolCallId `json:"toolCallId"` } @@ -1055,8 +1295,8 @@ func (v *WaitForTerminalExitRequest) Validate() error { } type WaitForTerminalExitResponse struct { - ExitCode int `json:"exitCode,omitempty"` - Signal string `json:"signal,omitempty"` + ExitCode *int `json:"exitCode,omitempty"` + Signal *string `json:"signal,omitempty"` } func (v *WaitForTerminalExitResponse) Validate() error { From 9de508b5c916634cde9d5d7d365822a563e2a3f7 Mon Sep 17 00:00:00 2001 From: Thomas Kosiewski Date: Tue, 2 Sep 2025 12:29:05 +0200 Subject: [PATCH 12/22] test: add JSON parity tests with golden file validation Change-Id: Iad1f284ea06288dd72e8eab3b6c429ea88113fee Signed-off-by: Thomas Kosiewski --- go/acp_test.go | 6 +- go/cmd/generate/internal/emit/helpers.go | 304 ++- go/cmd/generate/internal/emit/jenwrap.go | 3 + go/cmd/generate/internal/emit/types.go | 725 +++---- go/example/agent/main.go | 89 +- go/example/claude-code/main.go | 8 +- go/example/client/main.go | 4 +- go/example/gemini/main.go | 8 +- go/example_agent_test.go | 34 +- go/example_client_test.go | 4 +- go/example_gemini_test.go | 2 +- go/helpers_gen.go | 356 +++- go/json_parity_test.go | 209 ++ .../json_golden/cancel_notification.json | 4 + go/testdata/json_golden/content_audio.json | 6 + go/testdata/json_golden/content_image.json | 6 + .../json_golden/content_resource_blob.json | 9 + .../json_golden/content_resource_link.json | 8 + .../json_golden/content_resource_text.json | 9 + go/testdata/json_golden/content_text.json | 5 + .../fs_read_text_file_request.json | 7 + .../fs_read_text_file_response.json | 4 + .../fs_write_text_file_request.json | 6 + .../json_golden/initialize_request.json | 10 + .../json_golden/initialize_response.json | 13 + .../json_golden/new_session_request.json | 12 + .../json_golden/new_session_response.json | 4 + .../permission_outcome_cancelled.json | 4 + .../permission_outcome_selected.json | 5 + go/testdata/json_golden/prompt_request.json | 18 + .../request_permission_request.json | 19 + .../request_permission_response_selected.json | 7 + .../session_update_agent_message_chunk.json | 8 + .../session_update_agent_thought_chunk.json | 8 + .../json_golden/session_update_plan.json | 16 + .../json_golden/session_update_tool_call.json | 8 + ...ssion_update_tool_call_update_content.json | 15 + .../session_update_user_message_chunk.json | 8 + .../tool_content_content_text.json | 8 + .../json_golden/tool_content_diff.json | 7 + go/types_gen.go | 1836 ++++++++++++++--- package.json | 2 +- 42 files changed, 2887 insertions(+), 937 deletions(-) create mode 100644 go/json_parity_test.go create mode 100644 go/testdata/json_golden/cancel_notification.json create mode 100644 go/testdata/json_golden/content_audio.json create mode 100644 go/testdata/json_golden/content_image.json create mode 100644 go/testdata/json_golden/content_resource_blob.json create mode 100644 go/testdata/json_golden/content_resource_link.json create mode 100644 go/testdata/json_golden/content_resource_text.json create mode 100644 go/testdata/json_golden/content_text.json create mode 100644 go/testdata/json_golden/fs_read_text_file_request.json create mode 100644 go/testdata/json_golden/fs_read_text_file_response.json create mode 100644 go/testdata/json_golden/fs_write_text_file_request.json create mode 100644 go/testdata/json_golden/initialize_request.json create mode 100644 go/testdata/json_golden/initialize_response.json create mode 100644 go/testdata/json_golden/new_session_request.json create mode 100644 go/testdata/json_golden/new_session_response.json create mode 100644 go/testdata/json_golden/permission_outcome_cancelled.json create mode 100644 go/testdata/json_golden/permission_outcome_selected.json create mode 100644 go/testdata/json_golden/prompt_request.json create mode 100644 go/testdata/json_golden/request_permission_request.json create mode 100644 go/testdata/json_golden/request_permission_response_selected.json create mode 100644 go/testdata/json_golden/session_update_agent_message_chunk.json create mode 100644 go/testdata/json_golden/session_update_agent_thought_chunk.json create mode 100644 go/testdata/json_golden/session_update_plan.json create mode 100644 go/testdata/json_golden/session_update_tool_call.json create mode 100644 go/testdata/json_golden/session_update_tool_call_update_content.json create mode 100644 go/testdata/json_golden/session_update_user_message_chunk.json create mode 100644 go/testdata/json_golden/tool_content_content_text.json create mode 100644 go/testdata/json_golden/tool_content_diff.json diff --git a/go/acp_test.go b/go/acp_test.go index 06cd606..96b7958 100644 --- a/go/acp_test.go +++ b/go/acp_test.go @@ -367,7 +367,11 @@ func TestConnectionHandlesNotifications(t *testing.T) { if err := agentSide.SessionUpdate(context.Background(), SessionNotification{ SessionId: "test-session", - Update: SessionUpdate{AgentMessageChunk: &SessionUpdateAgentMessageChunk{Content: ContentBlock{Type: "text", Text: &TextContent{Text: "Hello from agent"}}}}, + Update: SessionUpdate{ + AgentMessageChunk: &SessionUpdateAgentMessageChunk{ + Content: TextBlock("Hello from agent"), + }, + }, }); err != nil { t.Fatalf("sessionUpdate error: %v", err) } diff --git a/go/cmd/generate/internal/emit/helpers.go b/go/cmd/generate/internal/emit/helpers.go index a9c1da0..3b6e5dc 100644 --- a/go/cmd/generate/internal/emit/helpers.go +++ b/go/cmd/generate/internal/emit/helpers.go @@ -2,15 +2,19 @@ package emit import ( "bytes" + "fmt" "os" "path/filepath" + "sort" + "strings" "github.com/zed-industries/agent-client-protocol/go/cmd/generate/internal/load" + "github.com/zed-industries/agent-client-protocol/go/cmd/generate/internal/util" ) // WriteHelpersJen emits go/helpers_gen.go with small constructor helpers // for common union variants and a Ptr generic helper. -func WriteHelpersJen(outDir string, _ *load.Schema, _ *load.Meta) error { +func WriteHelpersJen(outDir string, schema *load.Schema, _ *load.Meta) error { f := NewFile("acp") f.HeaderComment("Code generated by acp-go-generator; DO NOT EDIT.") @@ -18,8 +22,7 @@ func WriteHelpersJen(outDir string, _ *load.Schema, _ *load.Meta) error { f.Comment("TextBlock constructs a text content block.") f.Func().Id("TextBlock").Params(Id("text").String()).Id("ContentBlock").Block( Return(Id("ContentBlock").Values(Dict{ - Id("Type"): Lit("text"), - Id("Text"): Op("&").Id("TextContent").Values(Dict{Id("Text"): Id("text")}), + Id("Text"): Op("&").Id("ContentBlockText").Values(Dict{Id("Type"): Lit("text"), Id("Text"): Id("text")}), })), ) f.Line() @@ -27,8 +30,7 @@ func WriteHelpersJen(outDir string, _ *load.Schema, _ *load.Meta) error { f.Comment("ImageBlock constructs an inline image content block with base64-encoded data.") f.Func().Id("ImageBlock").Params(Id("data").String(), Id("mimeType").String()).Id("ContentBlock").Block( Return(Id("ContentBlock").Values(Dict{ - Id("Type"): Lit("image"), - Id("Image"): Op("&").Id("ImageContent").Values(Dict{Id("Data"): Id("data"), Id("MimeType"): Id("mimeType")}), + Id("Image"): Op("&").Id("ContentBlockImage").Values(Dict{Id("Type"): Lit("image"), Id("Data"): Id("data"), Id("MimeType"): Id("mimeType")}), })), ) f.Line() @@ -36,8 +38,7 @@ func WriteHelpersJen(outDir string, _ *load.Schema, _ *load.Meta) error { f.Comment("AudioBlock constructs an inline audio content block with base64-encoded data.") f.Func().Id("AudioBlock").Params(Id("data").String(), Id("mimeType").String()).Id("ContentBlock").Block( Return(Id("ContentBlock").Values(Dict{ - Id("Type"): Lit("audio"), - Id("Audio"): Op("&").Id("AudioContent").Values(Dict{Id("Data"): Id("data"), Id("MimeType"): Id("mimeType")}), + Id("Audio"): Op("&").Id("ContentBlockAudio").Values(Dict{Id("Type"): Lit("audio"), Id("Data"): Id("data"), Id("MimeType"): Id("mimeType")}), })), ) f.Line() @@ -45,8 +46,7 @@ func WriteHelpersJen(outDir string, _ *load.Schema, _ *load.Meta) error { f.Comment("ResourceLinkBlock constructs a resource_link content block with a name and URI.") f.Func().Id("ResourceLinkBlock").Params(Id("name").String(), Id("uri").String()).Id("ContentBlock").Block( Return(Id("ContentBlock").Values(Dict{ - Id("Type"): Lit("resource_link"), - Id("ResourceLink"): Op("&").Id("ResourceLinkContent").Values(Dict{Id("Name"): Id("name"), Id("Uri"): Id("uri")}), + Id("ResourceLink"): Op("&").Id("ContentBlockResourceLink").Values(Dict{Id("Type"): Lit("resource_link"), Id("Name"): Id("name"), Id("Uri"): Id("uri")}), })), ) f.Line() @@ -55,8 +55,7 @@ func WriteHelpersJen(outDir string, _ *load.Schema, _ *load.Meta) error { f.Func().Id("ResourceBlock").Params(Id("res").Id("EmbeddedResource")).Id("ContentBlock").Block( Var().Id("r").Id("EmbeddedResource").Op("=").Id("res"), Return(Id("ContentBlock").Values(Dict{ - Id("Type"): Lit("resource"), - Id("Resource"): Op("&").Id("r"), + Id("Resource"): Op("&").Id("ContentBlockResource").Values(Dict{Id("Type"): Lit("resource"), Id("Resource"): Id("r").Dot("Resource")}), })), ) f.Line() @@ -96,9 +95,292 @@ func WriteHelpersJen(outDir string, _ *load.Schema, _ *load.Meta) error { Return(Op("&").Id("v")), ) + // SessionUpdate helpers (friendly aliases) + f.Line() + f.Comment("UpdateUserMessage constructs a user_message_chunk update with the given content.") + f.Func().Id("UpdateUserMessage").Params(Id("content").Id("ContentBlock")).Id("SessionUpdate").Block( + Return(Id("SessionUpdate").Values(Dict{Id("UserMessageChunk"): Op("&").Id("SessionUpdateUserMessageChunk").Values(Dict{Id("Content"): Id("content")})})), + ) + f.Comment("UpdateUserMessageText constructs a user_message_chunk update from text.") + f.Func().Id("UpdateUserMessageText").Params(Id("text").String()).Id("SessionUpdate").Block( + Return(Id("UpdateUserMessage").Call(Id("TextBlock").Call(Id("text")))), + ) + + f.Comment("UpdateAgentMessage constructs an agent_message_chunk update with the given content.") + f.Func().Id("UpdateAgentMessage").Params(Id("content").Id("ContentBlock")).Id("SessionUpdate").Block( + Return(Id("SessionUpdate").Values(Dict{Id("AgentMessageChunk"): Op("&").Id("SessionUpdateAgentMessageChunk").Values(Dict{Id("Content"): Id("content")})})), + ) + f.Comment("UpdateAgentMessageText constructs an agent_message_chunk update from text.") + f.Func().Id("UpdateAgentMessageText").Params(Id("text").String()).Id("SessionUpdate").Block( + Return(Id("UpdateAgentMessage").Call(Id("TextBlock").Call(Id("text")))), + ) + + f.Comment("UpdateAgentThought constructs an agent_thought_chunk update with the given content.") + f.Func().Id("UpdateAgentThought").Params(Id("content").Id("ContentBlock")).Id("SessionUpdate").Block( + Return(Id("SessionUpdate").Values(Dict{Id("AgentThoughtChunk"): Op("&").Id("SessionUpdateAgentThoughtChunk").Values(Dict{Id("Content"): Id("content")})})), + ) + f.Comment("UpdateAgentThoughtText constructs an agent_thought_chunk update from text.") + f.Func().Id("UpdateAgentThoughtText").Params(Id("text").String()).Id("SessionUpdate").Block( + Return(Id("UpdateAgentThought").Call(Id("TextBlock").Call(Id("text")))), + ) + + f.Comment("UpdatePlan constructs a plan update with the provided entries.") + f.Func().Id("UpdatePlan").Params(Id("entries").Op("...").Id("PlanEntry")).Id("SessionUpdate").Block( + Return(Id("SessionUpdate").Values(Dict{Id("Plan"): Op("&").Id("SessionUpdatePlan").Values(Dict{Id("Entries"): Id("entries")})})), + ) + + // Tool call start helpers with functional options (friendly aliases) + f.Line() + f.Type().Id("ToolCallStartOpt").Func().Params(Id("tc").Op("*").Id("SessionUpdateToolCall")) + f.Comment("StartToolCall constructs a tool_call update with required fields and applies optional modifiers.") + f.Func().Id("StartToolCall").Params(Id("id").Id("ToolCallId"), Id("title").String(), Id("opts").Op("...").Id("ToolCallStartOpt")).Id("SessionUpdate").Block( + Id("tc").Op(":=").Id("SessionUpdateToolCall").Values(Dict{Id("ToolCallId"): Id("id"), Id("Title"): Id("title")}), + For(List(Id("_"), Id("opt")).Op(":=").Range().Id("opts")).Block(Id("opt").Call(Op("&").Id("tc"))), + Return(Id("SessionUpdate").Values(Dict{Id("ToolCall"): Op("&").Id("tc")})), + ) + f.Comment("WithStartKind sets the kind for a tool_call start update.") + f.Func().Id("WithStartKind").Params(Id("k").Id("ToolKind")).Id("ToolCallStartOpt").Block( + Return(Func().Params(Id("tc").Op("*").Id("SessionUpdateToolCall")).Block(Id("tc").Dot("Kind").Op("=").Id("k"))), + ) + f.Comment("WithStartStatus sets the status for a tool_call start update.") + f.Func().Id("WithStartStatus").Params(Id("s").Id("ToolCallStatus")).Id("ToolCallStartOpt").Block( + Return(Func().Params(Id("tc").Op("*").Id("SessionUpdateToolCall")).Block(Id("tc").Dot("Status").Op("=").Id("s"))), + ) + f.Comment("WithStartContent sets the initial content for a tool_call start update.") + f.Func().Id("WithStartContent").Params(Id("c").Index().Id("ToolCallContent")).Id("ToolCallStartOpt").Block( + Return(Func().Params(Id("tc").Op("*").Id("SessionUpdateToolCall")).Block(Id("tc").Dot("Content").Op("=").Id("c"))), + ) + f.Comment("WithStartLocations sets file locations and, if a single path is provided and rawInput is empty, mirrors it as rawInput.path.") + f.Func().Id("WithStartLocations").Params(Id("l").Index().Id("ToolCallLocation")).Id("ToolCallStartOpt").Block( + Return(Func().Params(Id("tc").Op("*").Id("SessionUpdateToolCall")).BlockFunc(func(g *Group) { + g.Id("tc").Dot("Locations").Op("=").Id("l") + g.If(Id("len").Call(Id("l")).Op("==").Lit(1).Op("&&").Id("l").Index(Lit(0)).Dot("Path").Op("!=").Lit("")).BlockFunc(func(h *Group) { + // initialize rawInput if nil + h.If(Id("tc").Dot("RawInput").Op("==").Nil()).Block( + Id("tc").Dot("RawInput").Op("=").Map(String()).Any().Values(Dict{Lit("path"): Id("l").Index(Lit(0)).Dot("Path")}), + ).Else().BlockFunc(func(b *Group) { + b.List(Id("m"), Id("ok")).Op(":=").Id("tc").Dot("RawInput").Assert(Map(String()).Any()) + b.If(Id("ok")).Block( + If(List(Id("_"), Id("exists")).Op(":=").Id("m").Index(Lit("path")), Op("!").Id("exists")).Block( + Id("m").Index(Lit("path")).Op("=").Id("l").Index(Lit(0)).Dot("Path"), + ), + ) + }) + }) + })), + ) + f.Comment("WithStartRawInput sets rawInput for a tool_call start update.") + f.Func().Id("WithStartRawInput").Params(Id("v").Any()).Id("ToolCallStartOpt").Block( + Return(Func().Params(Id("tc").Op("*").Id("SessionUpdateToolCall")).Block(Id("tc").Dot("RawInput").Op("=").Id("v"))), + ) + f.Comment("WithStartRawOutput sets rawOutput for a tool_call start update.") + f.Func().Id("WithStartRawOutput").Params(Id("v").Any()).Id("ToolCallStartOpt").Block( + Return(Func().Params(Id("tc").Op("*").Id("SessionUpdateToolCall")).Block(Id("tc").Dot("RawOutput").Op("=").Id("v"))), + ) + + // Tool call update helpers with functional options (pointer fields; friendly aliases) + f.Line() + f.Type().Id("ToolCallUpdateOpt").Func().Params(Id("tu").Op("*").Id("SessionUpdateToolCallUpdate")) + f.Comment("UpdateToolCall constructs a tool_call_update with the given ID and applies optional modifiers.") + f.Func().Id("UpdateToolCall").Params(Id("id").Id("ToolCallId"), Id("opts").Op("...").Id("ToolCallUpdateOpt")).Id("SessionUpdate").Block( + Id("tu").Op(":=").Id("SessionUpdateToolCallUpdate").Values(Dict{Id("ToolCallId"): Id("id")}), + For(List(Id("_"), Id("opt")).Op(":=").Range().Id("opts")).Block(Id("opt").Call(Op("&").Id("tu"))), + Return(Id("SessionUpdate").Values(Dict{Id("ToolCallUpdate"): Op("&").Id("tu")})), + ) + f.Comment("WithUpdateTitle sets the title for a tool_call_update.") + f.Func().Id("WithUpdateTitle").Params(Id("t").String()).Id("ToolCallUpdateOpt").Block( + Return(Func().Params(Id("tu").Op("*").Id("SessionUpdateToolCallUpdate")).Block(Id("tu").Dot("Title").Op("=").Id("Ptr").Call(Id("t")))), + ) + f.Comment("WithUpdateKind sets the kind for a tool_call_update.") + f.Func().Id("WithUpdateKind").Params(Id("k").Id("ToolKind")).Id("ToolCallUpdateOpt").Block( + Return(Func().Params(Id("tu").Op("*").Id("SessionUpdateToolCallUpdate")).Block(Id("tu").Dot("Kind").Op("=").Id("Ptr").Call(Id("k")))), + ) + f.Comment("WithUpdateStatus sets the status for a tool_call_update.") + f.Func().Id("WithUpdateStatus").Params(Id("s").Id("ToolCallStatus")).Id("ToolCallUpdateOpt").Block( + Return(Func().Params(Id("tu").Op("*").Id("SessionUpdateToolCallUpdate")).Block(Id("tu").Dot("Status").Op("=").Id("Ptr").Call(Id("s")))), + ) + f.Comment("WithUpdateContent replaces the content collection for a tool_call_update.") + f.Func().Id("WithUpdateContent").Params(Id("c").Index().Id("ToolCallContent")).Id("ToolCallUpdateOpt").Block( + Return(Func().Params(Id("tu").Op("*").Id("SessionUpdateToolCallUpdate")).Block(Id("tu").Dot("Content").Op("=").Id("c"))), + ) + f.Comment("WithUpdateLocations replaces the locations collection for a tool_call_update.") + f.Func().Id("WithUpdateLocations").Params(Id("l").Index().Id("ToolCallLocation")).Id("ToolCallUpdateOpt").Block( + Return(Func().Params(Id("tu").Op("*").Id("SessionUpdateToolCallUpdate")).Block(Id("tu").Dot("Locations").Op("=").Id("l"))), + ) + f.Comment("WithUpdateRawInput sets rawInput for a tool_call_update.") + f.Func().Id("WithUpdateRawInput").Params(Id("v").Any()).Id("ToolCallUpdateOpt").Block( + Return(Func().Params(Id("tu").Op("*").Id("SessionUpdateToolCallUpdate")).Block(Id("tu").Dot("RawInput").Op("=").Id("v"))), + ) + f.Comment("WithUpdateRawOutput sets rawOutput for a tool_call_update.") + f.Func().Id("WithUpdateRawOutput").Params(Id("v").Any()).Id("ToolCallUpdateOpt").Block( + Return(Func().Params(Id("tu").Op("*").Id("SessionUpdateToolCallUpdate")).Block(Id("tu").Dot("RawOutput").Op("=").Id("v"))), + ) + + // Schema-driven generic helpers: New(required fields only) + // Iterate definitions deterministically + keys := make([]string, 0, len(schema.Defs)) + for k := range schema.Defs { + keys = append(keys, k) + } + sort.Strings(keys) + for _, name := range keys { + def := schema.Defs[name] + if def == nil || def.DocsIgnore || len(def.OneOf) == 0 { + continue + } + // Skip string-const unions + if isStringConstUnion(def) { + continue + } + // Build variant info similarly to types emitter + type vinfo struct { + fieldName string + typeName string + discKey string + discValue string + required []string + props map[string]*load.Definition + } + discKey := "" + for _, v := range def.OneOf { + if v == nil { + continue + } + for k, pd := range v.Properties { + if pd != nil && pd.Const != nil { + discKey = k + break + } + } + if discKey != "" { + break + } + } + variants := []vinfo{} + for idx, v := range def.OneOf { + if v == nil { + continue + } + // compute type name per types emitter + tname := v.Title + if tname == "" { + if v.Ref != "" && strings.HasPrefix(v.Ref, "#/$defs/") { + tname = v.Ref[len("#/$defs/"):] + } else { + if discKey != "" { + if pd := v.Properties[discKey]; pd != nil && pd.Const != nil { + s := fmt.Sprint(pd.Const) + tname = name + util.ToExportedField(s) + } + } + if tname == "" { + tname = name + fmt.Sprintf("Variant%d", idx+1) + } + } + } + fieldName := tname + dv := "" + if discKey != "" { + if pd := v.Properties[discKey]; pd != nil && pd.Const != nil { + s := fmt.Sprint(pd.Const) + fieldName = util.ToExportedField(s) + dv = s + } + } + // collect required + req := make([]string, len(v.Required)) + copy(req, v.Required) + variants = append(variants, vinfo{fieldName: fieldName, typeName: tname, discKey: discKey, discValue: dv, required: req, props: v.Properties}) + } + // Emit helper per variant: func New(...) + for _, vi := range variants { + // params: all required props except const discriminator + params := []Code{} + assigns := Dict{} + for _, rk := range vi.required { + if rk == vi.discKey { + continue + } + pd := vi.props[rk] + if pd == nil { + continue + } + // build param using lower-cased name + pname := rk + // field id for struct literal + field := util.ToExportedField(rk) + params = append(params, Id(pname).Add(jenTypeFor(pd))) + assigns[Id(field)] = Id(pname) + } + // include const discriminant if present and field exists on struct + if vi.discKey != "" && vi.discValue != "" { + assigns[Id(util.ToExportedField(vi.discKey))] = Lit(vi.discValue) + } + // Construct variant literal and wrap + f.Comment(fmt.Sprintf("New%s%s constructs a %s using the '%s' variant.", name, vi.fieldName, name, vi.discValue)) + f.Func().Id("New" + name + vi.fieldName).Params(params...).Id(name).Block( + Return( + Id(name).Values(Dict{ + Id(vi.fieldName): Op("&").Id(vi.typeName).Values(assigns), + }), + ), + ) + f.Line() + } + } + + // Friendly aliases: opinionated tool call starters for common cases + // StartReadToolCall: sets kind=read, status=pending, locations=[{path}], rawInput={path} + f.Comment("StartReadToolCall constructs a 'tool_call' update for reading a file: kind=read, status=pending, locations=[{path}], rawInput={path}.") + f.Func().Id("StartReadToolCall").Params( + Id("id").Id("ToolCallId"), + Id("title").String(), + Id("path").String(), + Id("opts").Op("...").Id("ToolCallStartOpt"), + ).Id("SessionUpdate").Block( + Id("base").Op(":=").Index().Id("ToolCallStartOpt").Values( + Id("WithStartKind").Call(Id("ToolKindRead")), + Id("WithStartStatus").Call(Id("ToolCallStatusPending")), + Id("WithStartLocations").Call( + Index().Id("ToolCallLocation").Values( + Id("ToolCallLocation").Values(Dict{Id("Path"): Id("path")}), + ), + ), + Id("WithStartRawInput").Call(Map(String()).Any().Values(Dict{Lit("path"): Id("path")})), + ), + Id("args").Op(":=").Id("append").Call(Id("base"), Id("opts").Op("...")), + Return(Id("StartToolCall").Call(Id("id"), Id("title"), Id("args").Op("..."))), + ) + f.Line() + // StartEditToolCall: sets kind=edit, status=pending, locations=[{path}], rawInput={path, content} + f.Comment("StartEditToolCall constructs a 'tool_call' update for editing content: kind=edit, status=pending, locations=[{path}], rawInput={path, content}.") + f.Func().Id("StartEditToolCall").Params( + Id("id").Id("ToolCallId"), + Id("title").String(), + Id("path").String(), + Id("content").Any(), + Id("opts").Op("...").Id("ToolCallStartOpt"), + ).Id("SessionUpdate").Block( + Id("base").Op(":=").Index().Id("ToolCallStartOpt").Values( + Id("WithStartKind").Call(Id("ToolKindEdit")), + Id("WithStartStatus").Call(Id("ToolCallStatusPending")), + Id("WithStartLocations").Call( + Index().Id("ToolCallLocation").Values( + Id("ToolCallLocation").Values(Dict{Id("Path"): Id("path")}), + ), + ), + Id("WithStartRawInput").Call(Map(String()).Any().Values(Dict{Lit("path"): Id("path"), Lit("content"): Id("content")})), + ), + Id("args").Op(":=").Id("append").Call(Id("base"), Id("opts").Op("...")), + Return(Id("StartToolCall").Call(Id("id"), Id("title"), Id("args").Op("..."))), + ) + f.Line() + var buf bytes.Buffer if err := f.Render(&buf); err != nil { return err } return os.WriteFile(filepath.Join(outDir, "helpers_gen.go"), buf.Bytes(), 0o644) } + +// Note: isStringConstUnion exists in types emitter; we reference that file-level function diff --git a/go/cmd/generate/internal/emit/jenwrap.go b/go/cmd/generate/internal/emit/jenwrap.go index 23e7123..6691a72 100644 --- a/go/cmd/generate/internal/emit/jenwrap.go +++ b/go/cmd/generate/internal/emit/jenwrap.go @@ -14,6 +14,9 @@ var ( NewFile = jen.NewFile Id = jen.Id Lit = jen.Lit + Func = jen.Func + For = jen.For + Range = jen.Range Return = jen.Return Nil = jen.Nil String = jen.String diff --git a/go/cmd/generate/internal/emit/types.go b/go/cmd/generate/internal/emit/types.go index b56af4e..4591969 100644 --- a/go/cmd/generate/internal/emit/types.go +++ b/go/cmd/generate/internal/emit/types.go @@ -61,20 +61,12 @@ func WriteTypesJen(outDir string, schema *load.Schema, meta *load.Meta) error { f.Const().Defs(defs...) } f.Line() - case name == "ContentBlock": - emitContentBlockJen(f) - case name == "SessionUpdate": - emitSessionUpdateJen(f) - case len(def.AnyOf) > 0 && !def.DocsIgnore: - emitAnyOfUnionJen(f, name, def) - case len(def.OneOf) > 0 && !isStringConstUnion(def) && !def.DocsIgnore: + case len(def.AnyOf) > 0: + emitUnion(f, name, def.AnyOf, false) + case len(def.OneOf) > 0 && !isStringConstUnion(def): // Generic union generation for non-enum oneOf - // Reuse same path as anyOf by treating oneOf as anyOf here - // Temporarily map OneOf into def.AnyOf for emission - tmp := *def - tmp.AnyOf = def.OneOf - tmp.OneOf = nil - emitAnyOfUnionJen(f, name, &tmp) + // Use the same implementation, but require exactly one variant + emitUnion(f, name, def.OneOf, true) case ir.PrimaryType(def) == "object" && len(def.Properties) > 0: st := []Code{} req := map[string]struct{}{} @@ -94,7 +86,11 @@ func WriteTypesJen(outDir string, schema *load.Schema, meta *load.Meta) error { } tag := pk if _, ok := req[pk]; !ok { - tag = pk + ",omitempty" + // Default: omit if empty, except for specific always-present fields + // Ensure InitializeResponse.authMethods is always encoded (even when empty) + if !(name == "InitializeResponse" && pk == "authMethods") { + tag = pk + ",omitempty" + } } st = append(st, Id(field).Add(jenTypeForOptional(prop)).Tag(map[string]string{"json": tag})) } @@ -110,7 +106,8 @@ func WriteTypesJen(outDir string, schema *load.Schema, meta *load.Meta) error { } // validators for selected types - if strings.HasSuffix(name, "Request") || strings.HasSuffix(name, "Response") || strings.HasSuffix(name, "Notification") || name == "ContentBlock" || name == "SessionUpdate" || name == "ToolCallUpdate" { + // Note: oneOf union wrappers get a generic Validate emitted in emitUnion. + if strings.HasSuffix(name, "Request") || strings.HasSuffix(name, "Response") || strings.HasSuffix(name, "Notification") || name == "ToolCallUpdate" { emitValidateJen(f, name, def) } } @@ -221,43 +218,9 @@ func isStringConstUnion(def *load.Definition) bool { } // emitValidateJen generates validators for selected types (logic unchanged). + func emitValidateJen(f *File, name string, def *load.Definition) { switch name { - case "ContentBlock": - f.Func().Params(Id("c").Op("*").Id("ContentBlock")).Id("Validate").Params().Params(Error()).Block( - Switch(Id("c").Dot("Type")).Block( - Case(Lit("text")).Block(If(Id("c").Dot("Text").Op("==").Nil()).Block(Return(Qual("fmt", "Errorf").Call(Lit("contentblock.text missing"))))), - Case(Lit("image")).Block(If(Id("c").Dot("Image").Op("==").Nil()).Block(Return(Qual("fmt", "Errorf").Call(Lit("contentblock.image missing"))))), - Case(Lit("audio")).Block(If(Id("c").Dot("Audio").Op("==").Nil()).Block(Return(Qual("fmt", "Errorf").Call(Lit("contentblock.audio missing"))))), - Case(Lit("resource_link")).Block(If(Id("c").Dot("ResourceLink").Op("==").Nil()).Block(Return(Qual("fmt", "Errorf").Call(Lit("contentblock.resource_link missing"))))), - Case(Lit("resource")).Block(If(Id("c").Dot("Resource").Op("==").Nil()).Block(Return(Qual("fmt", "Errorf").Call(Lit("contentblock.resource missing"))))), - ), - Return(Nil()), - ) - return - case "ToolCallContent": - f.Func().Params(Id("t").Op("*").Id("ToolCallContent")).Id("Validate").Params().Params(Error()).Block( - Switch(Id("t").Dot("Type")).Block( - Case(Lit("content")).Block(If(Id("t").Dot("Content").Op("==").Nil()).Block(Return(Qual("fmt", "Errorf").Call(Lit("toolcallcontent.content missing"))))), - Case(Lit("diff")).Block(If(Id("t").Dot("Diff").Op("==").Nil()).Block(Return(Qual("fmt", "Errorf").Call(Lit("toolcallcontent.diff missing"))))), - Case(Lit("terminal")).Block(If(Id("t").Dot("Terminal").Op("==").Nil()).Block(Return(Qual("fmt", "Errorf").Call(Lit("toolcallcontent.terminal missing"))))), - ), - Return(Nil()), - ) - return - case "SessionUpdate": - f.Func().Params(Id("s").Op("*").Id("SessionUpdate")).Id("Validate").Params().Params(Error()).Block( - Var().Id("count").Int(), - If(Id("s").Dot("UserMessageChunk").Op("!=").Nil()).Block(Id("count").Op("++")), - If(Id("s").Dot("AgentMessageChunk").Op("!=").Nil()).Block(Id("count").Op("++")), - If(Id("s").Dot("AgentThoughtChunk").Op("!=").Nil()).Block(Id("count").Op("++")), - If(Id("s").Dot("ToolCall").Op("!=").Nil()).Block(Id("count").Op("++")), - If(Id("s").Dot("ToolCallUpdate").Op("!=").Nil()).Block(Id("count").Op("++")), - If(Id("s").Dot("Plan").Op("!=").Nil()).Block(Id("count").Op("++")), - If(Id("count").Op("!=").Lit(1)).Block(Return(Qual("fmt", "Errorf").Call(Lit("sessionupdate must have exactly one variant set")))), - Return(Nil()), - ) - return case "ToolCallUpdate": f.Func().Params(Id("t").Op("*").Id("ToolCallUpdate")).Id("Validate").Params().Params(Error()).Block( If(Id("t").Dot("ToolCallId").Op("==").Lit("")).Block(Return(Qual("fmt", "Errorf").Call(Lit("toolCallId is required")))), @@ -413,320 +376,199 @@ func jenTypeForOptional(d *load.Definition) Code { return jenTypeFor(d) } -// Specialized emitters copied from original (unchanged behavior). -func emitContentBlockJen(f *File) { - f.Type().Id("ResourceLinkContent").Struct( - Id("Annotations").Any().Tag(map[string]string{"json": "annotations,omitempty"}), - Id("Description").Op("*").String().Tag(map[string]string{"json": "description,omitempty"}), - Id("MimeType").Op("*").String().Tag(map[string]string{"json": "mimeType,omitempty"}), - Id("Name").String().Tag(map[string]string{"json": "name"}), - Id("Size").Op("*").Int64().Tag(map[string]string{"json": "size,omitempty"}), - Id("Title").Op("*").String().Tag(map[string]string{"json": "title,omitempty"}), - Id("Uri").String().Tag(map[string]string{"json": "uri"}), - ) - f.Line() - f.Type().Id("ContentBlock").Struct( - Id("Type").String().Tag(map[string]string{"json": "type"}), - Id("Text").Op("*").Id("TextContent").Tag(map[string]string{"json": "-"}), - Id("Image").Op("*").Id("ImageContent").Tag(map[string]string{"json": "-"}), - Id("Audio").Op("*").Id("AudioContent").Tag(map[string]string{"json": "-"}), - Id("ResourceLink").Op("*").Id("ResourceLinkContent").Tag(map[string]string{"json": "-"}), - Id("Resource").Op("*").Id("EmbeddedResource").Tag(map[string]string{"json": "-"}), - ) - f.Line() - f.Func().Params(Id("c").Op("*").Id("ContentBlock")).Id("UnmarshalJSON").Params(Id("b").Index().Byte()).Error().Block( - Var().Id("probe").Struct(Id("Type").String().Tag(map[string]string{"json": "type"})), - If(List(Id("err")).Op(":=").Qual("encoding/json", "Unmarshal").Call(Id("b"), Op("&").Id("probe")), Id("err").Op("!=").Nil()).Block(Return(Id("err"))), - Id("c").Dot("Type").Op("=").Id("probe").Dot("Type"), - Switch(Id("probe").Dot("Type")).Block( - Case(Lit("text")).Block( - Var().Id("v").Id("TextContent"), - If(List(Id("err")).Op(":=").Qual("encoding/json", "Unmarshal").Call(Id("b"), Op("&").Id("v")), Id("err").Op("!=").Nil()).Block(Return(Id("err"))), - Id("c").Dot("Text").Op("=").Op("&").Id("v"), - ), - Case(Lit("image")).Block( - Var().Id("v").Id("ImageContent"), - If(List(Id("err")).Op(":=").Qual("encoding/json", "Unmarshal").Call(Id("b"), Op("&").Id("v")), Id("err").Op("!=").Nil()).Block(Return(Id("err"))), - Id("c").Dot("Image").Op("=").Op("&").Id("v"), - ), - Case(Lit("audio")).Block( - Var().Id("v").Id("AudioContent"), - If(List(Id("err")).Op(":=").Qual("encoding/json", "Unmarshal").Call(Id("b"), Op("&").Id("v")), Id("err").Op("!=").Nil()).Block(Return(Id("err"))), - Id("c").Dot("Audio").Op("=").Op("&").Id("v"), - ), - Case(Lit("resource_link")).Block( - Var().Id("v").Id("ResourceLinkContent"), - If(List(Id("err")).Op(":=").Qual("encoding/json", "Unmarshal").Call(Id("b"), Op("&").Id("v")), Id("err").Op("!=").Nil()).Block(Return(Id("err"))), - Id("c").Dot("ResourceLink").Op("=").Op("&").Id("v"), - ), - Case(Lit("resource")).Block( - Var().Id("v").Id("EmbeddedResource"), - If(List(Id("err")).Op(":=").Qual("encoding/json", "Unmarshal").Call(Id("b"), Op("&").Id("v")), Id("err").Op("!=").Nil()).Block(Return(Id("err"))), - Id("c").Dot("Resource").Op("=").Op("&").Id("v"), - ), - ), - Return(Nil()), - ) - f.Func().Params(Id("c").Id("ContentBlock")).Id("MarshalJSON").Params().Params(Index().Byte(), Error()).Block( - Switch(Id("c").Dot("Type")).Block( - Case(Lit("text")).Block( - If(Id("c").Dot("Text").Op("!=").Nil()).Block( - Return(Qual("encoding/json", "Marshal").Call(Map(String()).Any().Values(Dict{ - Lit("type"): Lit("text"), - Lit("text"): Id("c").Dot("Text").Dot("Text"), - }))), - ), - ), - Case(Lit("image")).Block( - If(Id("c").Dot("Image").Op("!=").Nil()).Block( - Return(Qual("encoding/json", "Marshal").Call(Map(String()).Any().Values(Dict{ - Lit("type"): Lit("image"), - Lit("data"): Id("c").Dot("Image").Dot("Data"), - Lit("mimeType"): Id("c").Dot("Image").Dot("MimeType"), - Lit("uri"): Id("c").Dot("Image").Dot("Uri"), - }))), - ), - ), - Case(Lit("audio")).Block( - If(Id("c").Dot("Audio").Op("!=").Nil()).Block( - Return(Qual("encoding/json", "Marshal").Call(Map(String()).Any().Values(Dict{ - Lit("type"): Lit("audio"), - Lit("data"): Id("c").Dot("Audio").Dot("Data"), - Lit("mimeType"): Id("c").Dot("Audio").Dot("MimeType"), - }))), - ), - ), - Case(Lit("resource_link")).Block( - If(Id("c").Dot("ResourceLink").Op("!=").Nil()).Block( - Return(Qual("encoding/json", "Marshal").Call(Map(String()).Any().Values(Dict{ - Lit("type"): Lit("resource_link"), - Lit("name"): Id("c").Dot("ResourceLink").Dot("Name"), - Lit("uri"): Id("c").Dot("ResourceLink").Dot("Uri"), - Lit("description"): Id("c").Dot("ResourceLink").Dot("Description"), - Lit("mimeType"): Id("c").Dot("ResourceLink").Dot("MimeType"), - Lit("size"): Id("c").Dot("ResourceLink").Dot("Size"), - Lit("title"): Id("c").Dot("ResourceLink").Dot("Title"), - }))), - ), - ), - Case(Lit("resource")).Block( - If(Id("c").Dot("Resource").Op("!=").Nil()).Block( - Return(Qual("encoding/json", "Marshal").Call(Map(String()).Any().Values(Dict{ - Lit("type"): Lit("resource"), - Lit("resource"): Id("c").Dot("Resource").Dot("Resource"), - }))), - ), - ), - ), - Return(Index().Byte().Values(), Nil()), - ) - f.Line() -} - -func emitToolCallContentJen(f *File) { - f.Type().Id("DiffContent").Struct( - Id("NewText").String().Tag(map[string]string{"json": "newText"}), - Id("OldText").Op("*").String().Tag(map[string]string{"json": "oldText,omitempty"}), - Id("Path").String().Tag(map[string]string{"json": "path"}), - ) - f.Type().Id("TerminalRef").Struct(Id("TerminalId").String().Tag(map[string]string{"json": "terminalId"})) - f.Line() - f.Type().Id("ToolCallContent").Struct( - Id("Type").String().Tag(map[string]string{"json": "type"}), - Id("Content").Op("*").Id("ContentBlock").Tag(map[string]string{"json": "-"}), - Id("Diff").Op("*").Id("DiffContent").Tag(map[string]string{"json": "-"}), - Id("Terminal").Op("*").Id("TerminalRef").Tag(map[string]string{"json": "-"}), - ) - f.Line() - f.Func().Params(Id("t").Op("*").Id("ToolCallContent")).Id("UnmarshalJSON").Params(Id("b").Index().Byte()).Error().Block( - Var().Id("probe").Struct(Id("Type").String().Tag(map[string]string{"json": "type"})), - If(List(Id("err")).Op(":=").Qual("encoding/json", "Unmarshal").Call(Id("b"), Op("&").Id("probe")), Id("err").Op("!=").Nil()).Block(Return(Id("err"))), - Id("t").Dot("Type").Op("=").Id("probe").Dot("Type"), - Switch(Id("probe").Dot("Type")).Block( - Case(Lit("content")).Block( - Var().Id("v").Struct( - Id("Type").String().Tag(map[string]string{"json": "type"}), - Id("Content").Id("ContentBlock").Tag(map[string]string{"json": "content"}), - ), - If(List(Id("err")).Op(":=").Qual("encoding/json", "Unmarshal").Call(Id("b"), Op("&").Id("v")), Id("err").Op("!=").Nil()).Block(Return(Id("err"))), - Id("t").Dot("Content").Op("=").Op("&").Id("v").Dot("Content"), - ), - Case(Lit("diff")).Block( - Var().Id("v").Id("DiffContent"), - If(List(Id("err")).Op(":=").Qual("encoding/json", "Unmarshal").Call(Id("b"), Op("&").Id("v")), Id("err").Op("!=").Nil()).Block(Return(Id("err"))), - Id("t").Dot("Diff").Op("=").Op("&").Id("v"), - ), - Case(Lit("terminal")).Block( - Var().Id("v").Id("TerminalRef"), - If(List(Id("err")).Op(":=").Qual("encoding/json", "Unmarshal").Call(Id("b"), Op("&").Id("v")), Id("err").Op("!=").Nil()).Block(Return(Id("err"))), - Id("t").Dot("Terminal").Op("=").Op("&").Id("v"), - ), - ), - Return(Nil()), - ) - f.Line() -} - -func emitEmbeddedResourceResourceJen(f *File) { - f.Type().Id("EmbeddedResourceResource").Struct( - Id("TextResourceContents").Op("*").Id("TextResourceContents").Tag(map[string]string{"json": "-"}), - Id("BlobResourceContents").Op("*").Id("BlobResourceContents").Tag(map[string]string{"json": "-"}), - ) - f.Line() - f.Func().Params(Id("e").Op("*").Id("EmbeddedResourceResource")).Id("UnmarshalJSON").Params(Id("b").Index().Byte()).Error().Block( - Var().Id("m").Map(String()).Qual("encoding/json", "RawMessage"), - If(List(Id("err")).Op(":=").Qual("encoding/json", "Unmarshal").Call(Id("b"), Op("&").Id("m")), Id("err").Op("!=").Nil()).Block(Return(Id("err"))), - If(List(Id("_"), Id("ok")).Op(":=").Id("m").Index(Lit("text")), Id("ok")).Block( - Var().Id("v").Id("TextResourceContents"), - If(List(Id("err")).Op(":=").Qual("encoding/json", "Unmarshal").Call(Id("b"), Op("&").Id("v")), Id("err").Op("!=").Nil()).Block(Return(Id("err"))), - Id("e").Dot("TextResourceContents").Op("=").Op("&").Id("v"), - Return(Nil()), - ), - If(List(Id("_"), Id("ok2")).Op(":=").Id("m").Index(Lit("blob")), Id("ok2")).Block( - Var().Id("v").Id("BlobResourceContents"), - If(List(Id("err")).Op(":=").Qual("encoding/json", "Unmarshal").Call(Id("b"), Op("&").Id("v")), Id("err").Op("!=").Nil()).Block(Return(Id("err"))), - Id("e").Dot("BlobResourceContents").Op("=").Op("&").Id("v"), - Return(Nil()), - ), - Return(Nil()), - ) - f.Line() -} - // emitAvailableCommandInputJen generates a concrete variant type for anyOf and a thin union wrapper // that supports JSON unmarshal by probing object shape. Currently the schema defines one variant // (title: UnstructuredCommandInput) with a required 'hint' field. -func emitAnyOfUnionJen(f *File, name string, def *load.Definition) { - // Collect variant names and generate inline structs for object variants if needed +func emitUnion(f *File, name string, defs []*load.Definition, exactlyOne bool) { type variantInfo struct { - fieldName string - typeName string - required []string - isObject bool - consts map[string]any + fieldName string + typeName string + required []string + isObject bool + discValue string + constPairs [][2]string + isNull bool } variants := []variantInfo{} - for idx, v := range def.AnyOf { + discKey := "" + // discover discriminator key if present (any const property) + for _, v := range defs { + if v == nil { + continue + } + for k, pd := range v.Properties { + if pd != nil && pd.Const != nil { + discKey = k + break + } + } + if discKey != "" { + break + } + } + for idx, v := range defs { if v == nil { continue } + // Detect null-only variant + isNull := false + if s, ok := v.Type.(string); ok && s == "null" { + isNull = true + } tname := v.Title if tname == "" { - if v.Ref != "" { - // derive from ref path - if strings.HasPrefix(v.Ref, "#/$defs/") { - tname = v.Ref[len("#/$defs/"):] - } + if v.Ref != "" && strings.HasPrefix(v.Ref, "#/$defs/") { + tname = v.Ref[len("#/$defs/"):] } else { - // Derive from const outcome/type if present - if out, ok := v.Properties["outcome"]; ok && out != nil && out.Const != nil { - s := fmt.Sprint(out.Const) - tname = name + util.ToExportedField(s) - } else if typ, ok2 := v.Properties["type"]; ok2 && typ != nil && typ.Const != nil { - s := fmt.Sprint(typ.Const) - tname = name + util.ToExportedField(s) - } else { + if discKey != "" { + if pd := v.Properties[discKey]; pd != nil && pd.Const != nil { + s := fmt.Sprint(pd.Const) + tname = name + util.ToExportedField(s) + } + } + if tname == "" { tname = name + fmt.Sprintf("Variant%d", idx+1) } } } fieldName := tname - if out, ok := v.Properties["outcome"]; ok && out != nil && out.Const != nil { - s := fmt.Sprint(out.Const) - fieldName = util.ToExportedField(s) - } else if typ, ok2 := v.Properties["type"]; ok2 && typ != nil && typ.Const != nil { - s := fmt.Sprint(typ.Const) - fieldName = util.ToExportedField(s) - } - // If this variant is an inline object, generate its struct - isObj := len(v.Properties) > 0 - if isObj && v.Ref == "" { - st := []Code{} - req := map[string]struct{}{} - for _, r := range v.Required { - req[r] = struct{}{} + dv := "" + if discKey != "" { + if pd := v.Properties[discKey]; pd != nil && pd.Const != nil { + s := fmt.Sprint(pd.Const) + fieldName = util.ToExportedField(s) + dv = s } - pkeys := make([]string, 0, len(v.Properties)) - for pk := range v.Properties { - pkeys = append(pkeys, pk) - } - sort.Strings(pkeys) - // Variant doc comment - if v.Description != "" { - f.Comment(util.SanitizeComment(v.Description)) + } + isObj := len(v.Properties) > 0 + // collect const properties (e.g., type, outcome) + consts := [][2]string{} + for pk, pd := range v.Properties { + if pd != nil && pd.Const != nil { + if s, ok := pd.Const.(string); ok { + consts = append(consts, [2]string{pk, s}) + } } - for _, pk := range pkeys { - pDef := v.Properties[pk] - field := util.ToExportedField(pk) - if pDef.Description != "" { - st = append(st, Comment(util.SanitizeComment(pDef.Description))) + } + if (isObj || isNull) && v.Ref == "" { + st := []Code{} + if !isNull { + req := map[string]struct{}{} + for _, r := range v.Required { + req[r] = struct{}{} } - tag := pk - if _, ok := req[pk]; !ok { - tag = pk + ",omitempty" + pkeys := make([]string, 0, len(v.Properties)) + for pk := range v.Properties { + pkeys = append(pkeys, pk) + } + sort.Strings(pkeys) + if v.Description != "" { + f.Comment(util.SanitizeComment(v.Description)) + } + for _, pk := range pkeys { + pDef := v.Properties[pk] + field := util.ToExportedField(pk) + if pDef.Description != "" { + st = append(st, Comment(util.SanitizeComment(pDef.Description))) + } + tag := pk + if _, ok := req[pk]; !ok { + tag = pk + ",omitempty" + } + st = append(st, Id(field).Add(jenTypeForOptional(pDef)).Tag(map[string]string{"json": tag})) } - st = append(st, Id(field).Add(jenTypeForOptional(pDef)).Tag(map[string]string{"json": tag})) } f.Type().Id(tname).Struct(st...) f.Line() } - // Collect const properties for detection - consts := map[string]any{} - for pk, pd := range v.Properties { - if pd != nil && pd.Const != nil { - consts[pk] = pd.Const - } - } - variants = append(variants, variantInfo{fieldName: fieldName, typeName: tname, required: v.Required, isObject: isObj, consts: consts}) + variants = append(variants, variantInfo{fieldName: fieldName, typeName: tname, required: v.Required, isObject: isObj, discValue: dv, constPairs: consts, isNull: isNull}) } - // Union wrapper + // wrapper st := []Code{} for _, vi := range variants { st = append(st, Id(vi.fieldName).Op("*").Id(vi.typeName).Tag(map[string]string{"json": "-"})) } f.Type().Id(name).Struct(st...) f.Line() - // Unmarshal: prefer required-field presence checks for object variants, then fallback to try-unmarshal for all + // Unmarshal f.Func().Params(Id("u").Op("*").Id(name)).Id("UnmarshalJSON").Params(Id("b").Index().Byte()).Error().BlockFunc(func(g *Group) { + // Handle literal null if a null-only variant exists + { + varNullHandled := false + for _, vi := range variants { + if vi.isNull { + // emit once for the first null variant + if !varNullHandled { + g.If(Id("string").Call(Id("b")).Op("==").Lit("null")).Block( + Var().Id("v").Id(vi.typeName), + Id("u").Dot(vi.fieldName).Op("=").Op("&").Id("v"), + Return(Nil()), + ) + varNullHandled = true + } + } + } + } g.Var().Id("m").Map(String()).Qual("encoding/json", "RawMessage") g.If(List(Id("err")).Op(":=").Qual("encoding/json", "Unmarshal").Call(Id("b"), Op("&").Id("m")), Id("err").Op("!=").Nil()).Block(Return(Id("err"))) - // Try required-field detection for object variants + // Prefer discriminator-based dispatch when available (e.g. "type", "outcome") + if discKey != "" { + g.BlockFunc(func(h *Group) { + h.Var().Id("disc").String() + h.If(List(Id("v"), Id("ok")).Op(":=").Id("m").Index(Lit(discKey)), Id("ok")).Block( + Qual("encoding/json", "Unmarshal").Call(Id("v"), Op("&").Id("disc")), + ) + h.Switch(Id("disc")).BlockFunc(func(sw *Group) { + for _, vi := range variants { + if vi.discValue != "" { + sw.Case(Lit(vi.discValue)).Block( + Var().Id("v").Id(vi.typeName), + If(Qual("encoding/json", "Unmarshal").Call(Id("b"), Op("&").Id("v")).Op("!=").Nil()).Block(Return(Qual("errors", "New").Call(Lit("invalid variant payload")))), + Id("u").Dot(vi.fieldName).Op("=").Op("&").Id("v"), + Return(Nil()), + ) + } + } + }) + }) + } + // Special-case: EmbeddedResourceResource variants distinguished by keys + if name == "EmbeddedResourceResource" { + g.If(List(Id("_"), Id("ok")).Op(":=").Id("m").Index(Lit("text")), Id("ok")).Block( + Var().Id("v").Id("TextResourceContents"), + If(Qual("encoding/json", "Unmarshal").Call(Id("b"), Op("&").Id("v")).Op("!=").Nil()).Block(Return(Qual("errors", "New").Call(Lit("invalid variant payload")))), + Id("u").Dot("TextResourceContents").Op("=").Op("&").Id("v"), + Return(Nil()), + ) + g.If(List(Id("_"), Id("ok")).Op(":=").Id("m").Index(Lit("blob")), Id("ok")).Block( + Var().Id("v").Id("BlobResourceContents"), + If(Qual("encoding/json", "Unmarshal").Call(Id("b"), Op("&").Id("v")).Op("!=").Nil()).Block(Return(Qual("errors", "New").Call(Lit("invalid variant payload")))), + Id("u").Dot("BlobResourceContents").Op("=").Op("&").Id("v"), + Return(Nil()), + ) + } + // required-key match for _, vi := range variants { if vi.isObject && len(vi.required) > 0 { - stmts := []Code{ - Var().Id("v").Id(vi.typeName), - Var().Id("match").Bool().Op("=").Lit(true), - } - for _, rk := range vi.required { - stmts = append(stmts, If(List(Id("_"), Id("ok")).Op(":=").Id("m").Index(Lit(rk)), Op("!").Id("ok")).Block(Id("match").Op("=").Lit(false))) - } - // Check const-valued fields - for ck, cv := range vi.consts { - // read m[ck] and compare to const value (stringify for simplicity) - stmts = append(stmts, - Var().Id("raw").Qual("encoding/json", "RawMessage"), Var().Id("ok").Bool(), - List(Id("raw"), Id("ok")).Op("=").Id("m").Index(Lit(ck)), - If(Op("!").Id("ok")).Block(Id("match").Op("=").Lit(false)), - If(Id("ok")).Block( - Var().Id("tmp").Any(), - If(List(Id("err")).Op(":=").Qual("encoding/json", "Unmarshal").Call(Id("raw"), Op("&").Id("tmp")), Id("err").Op("!=").Nil()).Block(Return(Id("err"))), - If(Qual("fmt", "Sprint").Call(Id("tmp")).Op("!=").Qual("fmt", "Sprint").Call(Lit(cv))).Block(Id("match").Op("=").Lit(false)), - ), + g.BlockFunc(func(h *Group) { + h.Var().Id("v").Id(vi.typeName) + h.Var().Id("match").Bool().Op("=").Lit(true) + for _, rk := range vi.required { + h.If(List(Id("_"), Id("ok")).Op(":=").Id("m").Index(Lit(rk)), Op("!").Id("ok")).Block(Id("match").Op("=").Lit(false)) + } + h.If(Id("match")).Block( + If(Qual("encoding/json", "Unmarshal").Call(Id("b"), Op("&").Id("v")).Op("!=").Nil()).Block(Return(Qual("errors", "New").Call(Lit("invalid variant payload")))), + Id("u").Dot(vi.fieldName).Op("=").Op("&").Id("v"), + Return(Nil()), ) - } - stmts = append(stmts, If(Id("match")).Block( - If(List(Id("err")).Op(":=").Qual("encoding/json", "Unmarshal").Call(Id("b"), Op("&").Id("v")), Id("err").Op("!=").Nil()).Block(Return(Id("err"))), - Id("u").Dot(vi.fieldName).Op("=").Op("&").Id("v"), - Return(Nil()), - )) - g.Block(stmts...) + }) } } - // Fallback: try to unmarshal into each variant sequentially + // fallback: try decode sequentially for _, vi := range variants { g.Block( Var().Id("v").Id(vi.typeName), - If(List(Id("err")).Op(":=").Qual("encoding/json", "Unmarshal").Call(Id("b"), Op("&").Id("v")), Id("err").Op("==").Nil()).Block( + If(Qual("encoding/json", "Unmarshal").Call(Id("b"), Op("&").Id("v")).Op("==").Nil()).Block( Id("u").Dot(vi.fieldName).Op("=").Op("&").Id("v"), Return(Nil()), ), @@ -734,151 +576,112 @@ func emitAnyOfUnionJen(f *File, name string, def *load.Definition) { } g.Return(Nil()) }) - // Marshal: pick first non-nil + // Marshal f.Func().Params(Id("u").Id(name)).Id("MarshalJSON").Params().Params(Index().Byte(), Error()).BlockFunc(func(g *Group) { for _, vi := range variants { - g.If(Id("u").Dot(vi.fieldName).Op("!=").Nil()).Block( - Return(Qual("encoding/json", "Marshal").Call(Op("*").Id("u").Dot(vi.fieldName))), - ) + g.If(Id("u").Dot(vi.fieldName).Op("!=").Nil()).BlockFunc(func(gg *Group) { + // Null-only variant encodes to JSON null + if vi.isNull { + gg.Return(Qual("encoding/json", "Marshal").Call(Nil())) + } + // Marshal variant to map for discriminant injection and shaping + gg.Var().Id("m").Map(String()).Any() + gg.List(Id("_b"), Id("_e")).Op(":=").Qual("encoding/json", "Marshal").Call(Op("*").Id("u").Dot(vi.fieldName)) + gg.If(Id("_e").Op("!=").Nil()).Block(Return(Index().Byte().Values(), Id("_e"))) + gg.If(Qual("encoding/json", "Unmarshal").Call(Id("_b"), Op("&").Id("m")).Op("!=").Nil()).Block(Return(Index().Byte().Values(), Qual("errors", "New").Call(Lit("invalid variant payload")))) + // Inject const discriminants + if len(vi.constPairs) > 0 { + for _, kv := range vi.constPairs { + gg.Id("m").Index(Lit(kv[0])).Op("=").Lit(kv[1]) + } + } + // Special shaping for ContentBlock variants to preserve exact wire JSON + if name == "ContentBlock" { + switch vi.discValue { + case "text": + gg.Block( + Var().Id("nm").Map(String()).Any(), + Id("nm").Op("=").Make(Map(String()).Any()), + Id("nm").Index(Lit("type")).Op("=").Lit("text"), + Id("nm").Index(Lit("text")).Op("=").Id("m").Index(Lit("text")), + Return(Qual("encoding/json", "Marshal").Call(Id("nm"))), + ) + case "image": + gg.Block( + Var().Id("nm").Map(String()).Any(), + Id("nm").Op("=").Make(Map(String()).Any()), + Id("nm").Index(Lit("type")).Op("=").Lit("image"), + Id("nm").Index(Lit("data")).Op("=").Id("m").Index(Lit("data")), + Id("nm").Index(Lit("mimeType")).Op("=").Id("m").Index(Lit("mimeType")), + // Only include uri if present; do not emit null + If(List(Id("_v"), Id("_ok")).Op(":=").Id("m").Index(Lit("uri")), Id("_ok")).Block( + Id("nm").Index(Lit("uri")).Op("=").Id("_v"), + ), + Return(Qual("encoding/json", "Marshal").Call(Id("nm"))), + ) + case "audio": + gg.Block( + Var().Id("nm").Map(String()).Any(), + Id("nm").Op("=").Make(Map(String()).Any()), + Id("nm").Index(Lit("type")).Op("=").Lit("audio"), + Id("nm").Index(Lit("data")).Op("=").Id("m").Index(Lit("data")), + Id("nm").Index(Lit("mimeType")).Op("=").Id("m").Index(Lit("mimeType")), + Return(Qual("encoding/json", "Marshal").Call(Id("nm"))), + ) + case "resource_link": + gg.BlockFunc(func(b *Group) { + b.Var().Id("nm").Map(String()).Any() + b.Id("nm").Op("=").Make(Map(String()).Any()) + b.Id("nm").Index(Lit("type")).Op("=").Lit("resource_link") + b.Id("nm").Index(Lit("name")).Op("=").Id("m").Index(Lit("name")) + b.Id("nm").Index(Lit("uri")).Op("=").Id("m").Index(Lit("uri")) + // Only include optional keys if present + b.If(List(Id("v1"), Id("ok1")).Op(":=").Id("m").Index(Lit("description")), Id("ok1")).Block( + Id("nm").Index(Lit("description")).Op("=").Id("v1"), + ) + b.If(List(Id("v2"), Id("ok2")).Op(":=").Id("m").Index(Lit("mimeType")), Id("ok2")).Block( + Id("nm").Index(Lit("mimeType")).Op("=").Id("v2"), + ) + b.If(List(Id("v3"), Id("ok3")).Op(":=").Id("m").Index(Lit("size")), Id("ok3")).Block( + Id("nm").Index(Lit("size")).Op("=").Id("v3"), + ) + b.If(List(Id("v4"), Id("ok4")).Op(":=").Id("m").Index(Lit("title")), Id("ok4")).Block( + Id("nm").Index(Lit("title")).Op("=").Id("v4"), + ) + b.Return(Qual("encoding/json", "Marshal").Call(Id("nm"))) + }) + case "resource": + gg.Block( + Var().Id("nm").Map(String()).Any(), + Id("nm").Op("=").Make(Map(String()).Any()), + Id("nm").Index(Lit("type")).Op("=").Lit("resource"), + Id("nm").Index(Lit("resource")).Op("=").Id("m").Index(Lit("resource")), + Return(Qual("encoding/json", "Marshal").Call(Id("nm"))), + ) + } + } + // default: remarshal possibly with injected discriminant + if name != "ContentBlock" { + gg.Return(Qual("encoding/json", "Marshal").Call(Id("m"))) + } + }) } g.Return(Index().Byte().Values(), Nil()) }) f.Line() -} -func emitSessionUpdateJen(f *File) { - f.Type().Id("SessionUpdateUserMessageChunk").Struct(Id("Content").Id("ContentBlock").Tag(map[string]string{"json": "content"})) - f.Type().Id("SessionUpdateAgentMessageChunk").Struct(Id("Content").Id("ContentBlock").Tag(map[string]string{"json": "content"})) - f.Type().Id("SessionUpdateAgentThoughtChunk").Struct(Id("Content").Id("ContentBlock").Tag(map[string]string{"json": "content"})) - f.Type().Id("SessionUpdateToolCall").Struct( - Id("Content").Index().Id("ToolCallContent").Tag(map[string]string{"json": "content,omitempty"}), - Id("Kind").Id("ToolKind").Tag(map[string]string{"json": "kind,omitempty"}), - Id("Locations").Index().Id("ToolCallLocation").Tag(map[string]string{"json": "locations,omitempty"}), - Id("RawInput").Any().Tag(map[string]string{"json": "rawInput,omitempty"}), - Id("RawOutput").Any().Tag(map[string]string{"json": "rawOutput,omitempty"}), - Id("Status").Id("ToolCallStatus").Tag(map[string]string{"json": "status,omitempty"}), - Id("Title").String().Tag(map[string]string{"json": "title"}), - Id("ToolCallId").Id("ToolCallId").Tag(map[string]string{"json": "toolCallId"}), - ) - f.Type().Id("SessionUpdateToolCallUpdate").Struct( - Id("Content").Index().Id("ToolCallContent").Tag(map[string]string{"json": "content,omitempty"}), - Id("Kind").Any().Tag(map[string]string{"json": "kind,omitempty"}), - Id("Locations").Index().Id("ToolCallLocation").Tag(map[string]string{"json": "locations,omitempty"}), - Id("RawInput").Any().Tag(map[string]string{"json": "rawInput,omitempty"}), - Id("RawOutput").Any().Tag(map[string]string{"json": "rawOutput,omitempty"}), - Id("Status").Any().Tag(map[string]string{"json": "status,omitempty"}), - Id("Title").Op("*").String().Tag(map[string]string{"json": "title,omitempty"}), - Id("ToolCallId").Id("ToolCallId").Tag(map[string]string{"json": "toolCallId"}), - ) - f.Type().Id("SessionUpdatePlan").Struct(Id("Entries").Index().Id("PlanEntry").Tag(map[string]string{"json": "entries"})) - f.Line() - f.Type().Id("SessionUpdate").Struct( - Id("UserMessageChunk").Op("*").Id("SessionUpdateUserMessageChunk").Tag(map[string]string{"json": "-"}), - Id("AgentMessageChunk").Op("*").Id("SessionUpdateAgentMessageChunk").Tag(map[string]string{"json": "-"}), - Id("AgentThoughtChunk").Op("*").Id("SessionUpdateAgentThoughtChunk").Tag(map[string]string{"json": "-"}), - Id("ToolCall").Op("*").Id("SessionUpdateToolCall").Tag(map[string]string{"json": "-"}), - Id("ToolCallUpdate").Op("*").Id("SessionUpdateToolCallUpdate").Tag(map[string]string{"json": "-"}), - Id("Plan").Op("*").Id("SessionUpdatePlan").Tag(map[string]string{"json": "-"}), - ) - f.Func().Params(Id("s").Op("*").Id("SessionUpdate")).Id("UnmarshalJSON").Params(Id("b").Index().Byte()).Error().Block( - Var().Id("m").Map(String()).Qual("encoding/json", "RawMessage"), - If(List(Id("err")).Op(":=").Qual("encoding/json", "Unmarshal").Call(Id("b"), Op("&").Id("m")), Id("err").Op("!=").Nil()).Block(Return(Id("err"))), - Var().Id("kind").String(), - If(List(Id("v"), Id("ok")).Op(":=").Id("m").Index(Lit("sessionUpdate")), Id("ok")).Block( - Qual("encoding/json", "Unmarshal").Call(Id("v"), Op("&").Id("kind")), - ), - Switch(Id("kind")).Block( - Case(Lit("user_message_chunk")).Block( - Var().Id("v").Id("SessionUpdateUserMessageChunk"), - If(List(Id("err")).Op(":=").Qual("encoding/json", "Unmarshal").Call(Id("b"), Op("&").Id("v")), Id("err").Op("!=").Nil()).Block(Return(Id("err"))), - Id("s").Dot("UserMessageChunk").Op("=").Op("&").Id("v"), - Return(Nil()), - ), - Case(Lit("agent_message_chunk")).Block( - Var().Id("v").Id("SessionUpdateAgentMessageChunk"), - If(List(Id("err")).Op(":=").Qual("encoding/json", "Unmarshal").Call(Id("b"), Op("&").Id("v")), Id("err").Op("!=").Nil()).Block(Return(Id("err"))), - Id("s").Dot("AgentMessageChunk").Op("=").Op("&").Id("v"), - Return(Nil()), - ), - Case(Lit("agent_thought_chunk")).Block( - Var().Id("v").Id("SessionUpdateAgentThoughtChunk"), - If(List(Id("err")).Op(":=").Qual("encoding/json", "Unmarshal").Call(Id("b"), Op("&").Id("v")), Id("err").Op("!=").Nil()).Block(Return(Id("err"))), - Id("s").Dot("AgentThoughtChunk").Op("=").Op("&").Id("v"), - Return(Nil()), - ), - Case(Lit("tool_call")).Block( - Var().Id("v").Id("SessionUpdateToolCall"), - If(List(Id("err")).Op(":=").Qual("encoding/json", "Unmarshal").Call(Id("b"), Op("&").Id("v")), Id("err").Op("!=").Nil()).Block(Return(Id("err"))), - Id("s").Dot("ToolCall").Op("=").Op("&").Id("v"), - Return(Nil()), - ), - Case(Lit("tool_call_update")).Block( - Var().Id("v").Id("SessionUpdateToolCallUpdate"), - If(List(Id("err")).Op(":=").Qual("encoding/json", "Unmarshal").Call(Id("b"), Op("&").Id("v")), Id("err").Op("!=").Nil()).Block(Return(Id("err"))), - Id("s").Dot("ToolCallUpdate").Op("=").Op("&").Id("v"), - Return(Nil()), - ), - Case(Lit("plan")).Block( - Var().Id("v").Id("SessionUpdatePlan"), - If(List(Id("err")).Op(":=").Qual("encoding/json", "Unmarshal").Call(Id("b"), Op("&").Id("v")), Id("err").Op("!=").Nil()).Block(Return(Id("err"))), - Id("s").Dot("Plan").Op("=").Op("&").Id("v"), - Return(Nil()), - ), - ), - Return(Nil()), - ) - f.Func().Params(Id("s").Id("SessionUpdate")).Id("MarshalJSON").Params().Params(Index().Byte(), Error()).Block( - If(Id("s").Dot("UserMessageChunk").Op("!=").Nil()).Block( - Return(Qual("encoding/json", "Marshal").Call(Map(String()).Any().Values(Dict{ - Lit("sessionUpdate"): Lit("user_message_chunk"), - Lit("content"): Id("s").Dot("UserMessageChunk").Dot("Content"), - }))), - ), - If(Id("s").Dot("AgentMessageChunk").Op("!=").Nil()).Block( - Return(Qual("encoding/json", "Marshal").Call(Map(String()).Any().Values(Dict{ - Lit("sessionUpdate"): Lit("agent_message_chunk"), - Lit("content"): Id("s").Dot("AgentMessageChunk").Dot("Content"), - }))), - ), - If(Id("s").Dot("AgentThoughtChunk").Op("!=").Nil()).Block( - Return(Qual("encoding/json", "Marshal").Call(Map(String()).Any().Values(Dict{ - Lit("sessionUpdate"): Lit("agent_thought_chunk"), - Lit("content"): Id("s").Dot("AgentThoughtChunk").Dot("Content"), - }))), - ), - If(Id("s").Dot("ToolCall").Op("!=").Nil()).Block( - Return(Qual("encoding/json", "Marshal").Call(Map(String()).Any().Values(Dict{ - Lit("sessionUpdate"): Lit("tool_call"), - Lit("content"): Id("s").Dot("ToolCall").Dot("Content"), - Lit("kind"): Id("s").Dot("ToolCall").Dot("Kind"), - Lit("locations"): Id("s").Dot("ToolCall").Dot("Locations"), - Lit("rawInput"): Id("s").Dot("ToolCall").Dot("RawInput"), - Lit("rawOutput"): Id("s").Dot("ToolCall").Dot("RawOutput"), - Lit("status"): Id("s").Dot("ToolCall").Dot("Status"), - Lit("title"): Id("s").Dot("ToolCall").Dot("Title"), - Lit("toolCallId"): Id("s").Dot("ToolCall").Dot("ToolCallId"), - }))), - ), - If(Id("s").Dot("ToolCallUpdate").Op("!=").Nil()).Block( - Return(Qual("encoding/json", "Marshal").Call(Map(String()).Any().Values(Dict{ - Lit("sessionUpdate"): Lit("tool_call_update"), - Lit("content"): Id("s").Dot("ToolCallUpdate").Dot("Content"), - Lit("kind"): Id("s").Dot("ToolCallUpdate").Dot("Kind"), - Lit("locations"): Id("s").Dot("ToolCallUpdate").Dot("Locations"), - Lit("rawInput"): Id("s").Dot("ToolCallUpdate").Dot("RawInput"), - Lit("rawOutput"): Id("s").Dot("ToolCallUpdate").Dot("RawOutput"), - Lit("status"): Id("s").Dot("ToolCallUpdate").Dot("Status"), - Lit("title"): Id("s").Dot("ToolCallUpdate").Dot("Title"), - Lit("toolCallId"): Id("s").Dot("ToolCallUpdate").Dot("ToolCallId"), - }))), - ), - If(Id("s").Dot("Plan").Op("!=").Nil()).Block( - Return(Qual("encoding/json", "Marshal").Call(Map(String()).Any().Values(Dict{ - Lit("sessionUpdate"): Lit("plan"), - Lit("entries"): Id("s").Dot("Plan").Dot("Entries"), - }))), - ), - Return(Index().Byte().Values(), Nil()), - ) - f.Line() + // Generic validator for oneOf unions: exactly one variant must be set + if exactlyOne { + f.Func().Params(Id("u").Op("*").Id(name)).Id("Validate").Params().Params(Error()).BlockFunc(func(g *Group) { + g.Var().Id("count").Int() + for _, vi := range variants { + g.If(Id("u").Dot(vi.fieldName).Op("!=").Nil()).Block(Id("count").Op("++")) + } + g.If(Id("count").Op("!=").Lit(1)).Block( + Return(Qual("errors", "New").Call(Lit(name + " must have exactly one variant set"))), + ) + g.Return(Nil()) + }) + f.Line() + } } diff --git a/go/example/agent/main.go b/go/example/agent/main.go index 9084879..aa7cb51 100644 --- a/go/example/agent/main.go +++ b/go/example/agent/main.go @@ -90,11 +90,7 @@ func (a *exampleAgent) simulateTurn(ctx context.Context, sid string) error { // disclaimer: stream a demo notice so clients see it's the example agent if err := a.conn.SessionUpdate(ctx, acp.SessionNotification{ SessionId: acp.SessionId(sid), - Update: acp.SessionUpdate{ - AgentMessageChunk: &acp.SessionUpdateAgentMessageChunk{ - Content: acp.TextBlock("ACP Go Example Agent — demo only (no AI model)."), - }, - }, + Update: acp.UpdateAgentMessageText("ACP Go Example Agent — demo only (no AI model)."), }); err != nil { return err } @@ -104,11 +100,7 @@ func (a *exampleAgent) simulateTurn(ctx context.Context, sid string) error { // initial message chunk if err := a.conn.SessionUpdate(ctx, acp.SessionNotification{ SessionId: acp.SessionId(sid), - Update: acp.SessionUpdate{ - AgentMessageChunk: &acp.SessionUpdateAgentMessageChunk{ - Content: acp.TextBlock("I'll help you with that. Let me start by reading some files to understand the current situation."), - }, - }, + Update: acp.UpdateAgentMessageText("I'll help you with that. Let me start by reading some files to understand the current situation."), }); err != nil { return err } @@ -119,14 +111,14 @@ func (a *exampleAgent) simulateTurn(ctx context.Context, sid string) error { // tool call without permission if err := a.conn.SessionUpdate(ctx, acp.SessionNotification{ SessionId: acp.SessionId(sid), - Update: acp.SessionUpdate{ToolCall: &acp.SessionUpdateToolCall{ - ToolCallId: acp.ToolCallId("call_1"), - Title: "Reading project files", - Kind: acp.ToolKindRead, - Status: acp.ToolCallStatusPending, - Locations: []acp.ToolCallLocation{{Path: "/project/README.md"}}, - RawInput: map[string]any{"path": "/project/README.md"}, - }}, + Update: acp.StartToolCall( + acp.ToolCallId("call_1"), + "Reading project files", + acp.WithStartKind(acp.ToolKindRead), + acp.WithStartStatus(acp.ToolCallStatusPending), + acp.WithStartLocations([]acp.ToolCallLocation{{Path: "/project/README.md"}}), + acp.WithStartRawInput(map[string]any{"path": "/project/README.md"}), + ), }); err != nil { return err } @@ -137,15 +129,12 @@ func (a *exampleAgent) simulateTurn(ctx context.Context, sid string) error { // update tool call completed if err := a.conn.SessionUpdate(ctx, acp.SessionNotification{ SessionId: acp.SessionId(sid), - Update: acp.SessionUpdate{ToolCallUpdate: &acp.SessionUpdateToolCallUpdate{ - ToolCallId: acp.ToolCallId("call_1"), - Status: "completed", - Content: []acp.ToolCallContent{ - acp.ToolContent( - acp.TextBlock("# My Project\n\nThis is a sample project...")), - }, - RawOutput: map[string]any{"content": "# My Project\n\nThis is a sample project..."}, - }}, + Update: acp.UpdateToolCall( + acp.ToolCallId("call_1"), + acp.WithUpdateStatus(acp.ToolCallStatusCompleted), + acp.WithUpdateContent([]acp.ToolCallContent{acp.ToolContent(acp.TextBlock("# My Project\n\nThis is a sample project..."))}), + acp.WithUpdateRawOutput(map[string]any{"content": "# My Project\n\nThis is a sample project..."}), + ), }); err != nil { return err } @@ -156,11 +145,7 @@ func (a *exampleAgent) simulateTurn(ctx context.Context, sid string) error { // more text if err := a.conn.SessionUpdate(ctx, acp.SessionNotification{ SessionId: acp.SessionId(sid), - Update: acp.SessionUpdate{ - AgentMessageChunk: &acp.SessionUpdateAgentMessageChunk{ - Content: acp.TextBlock(" Now I understand the project structure. I need to make some changes to improve it."), - }, - }, + Update: acp.UpdateAgentMessageText(" Now I understand the project structure. I need to make some changes to improve it."), }); err != nil { return err } @@ -171,14 +156,14 @@ func (a *exampleAgent) simulateTurn(ctx context.Context, sid string) error { // tool call requiring permission if err := a.conn.SessionUpdate(ctx, acp.SessionNotification{ SessionId: acp.SessionId(sid), - Update: acp.SessionUpdate{ToolCall: &acp.SessionUpdateToolCall{ - ToolCallId: acp.ToolCallId("call_2"), - Title: "Modifying critical configuration file", - Kind: acp.ToolKindEdit, - Status: acp.ToolCallStatusPending, - Locations: []acp.ToolCallLocation{{Path: "/project/config.json"}}, - RawInput: map[string]any{"path": "/project/config.json", "content": "{\"database\": {\"host\": \"new-host\"}}"}, - }}, + Update: acp.StartToolCall( + acp.ToolCallId("call_2"), + "Modifying critical configuration file", + acp.WithStartKind(acp.ToolKindEdit), + acp.WithStartStatus(acp.ToolCallStatusPending), + acp.WithStartLocations([]acp.ToolCallLocation{{Path: "/project/config.json"}}), + acp.WithStartRawInput(map[string]any{"path": "/project/config.json", "content": "{\"database\": {\"host\": \"new-host\"}}"}), + ), }); err != nil { return err } @@ -214,12 +199,12 @@ func (a *exampleAgent) simulateTurn(ctx context.Context, sid string) error { case "allow": if err := a.conn.SessionUpdate(ctx, acp.SessionNotification{ SessionId: acp.SessionId(sid), - Update: acp.SessionUpdate{ToolCallUpdate: &acp.SessionUpdateToolCallUpdate{ - ToolCallId: acp.ToolCallId("call_2"), - Status: acp.Ptr(acp.ToolCallStatusCompleted), - RawOutput: map[string]any{"success": true, "message": "Configuration updated"}, - Title: acp.Ptr("Modifying critical configuration file"), - }}, + Update: acp.UpdateToolCall( + acp.ToolCallId("call_2"), + acp.WithUpdateStatus(acp.ToolCallStatusCompleted), + acp.WithUpdateRawOutput(map[string]any{"success": true, "message": "Configuration updated"}), + acp.WithUpdateTitle("Modifying critical configuration file"), + ), }); err != nil { return err } @@ -228,11 +213,7 @@ func (a *exampleAgent) simulateTurn(ctx context.Context, sid string) error { } if err := a.conn.SessionUpdate(ctx, acp.SessionNotification{ SessionId: acp.SessionId(sid), - Update: acp.SessionUpdate{ - AgentMessageChunk: &acp.SessionUpdateAgentMessageChunk{ - Content: acp.TextBlock(" Perfect! I've successfully updated the configuration. The changes have been applied."), - }, - }, + Update: acp.UpdateAgentMessageText(" Perfect! I've successfully updated the configuration. The changes have been applied."), }); err != nil { return err } @@ -242,11 +223,7 @@ func (a *exampleAgent) simulateTurn(ctx context.Context, sid string) error { } if err := a.conn.SessionUpdate(ctx, acp.SessionNotification{ SessionId: acp.SessionId(sid), - Update: acp.SessionUpdate{ - AgentMessageChunk: &acp.SessionUpdateAgentMessageChunk{ - Content: acp.TextBlock(" I understand you prefer not to make that change. I'll skip the configuration update."), - }, - }, + Update: acp.UpdateAgentMessageText(" I understand you prefer not to make that change. I'll skip the configuration update."), }); err != nil { return err } diff --git a/go/example/claude-code/main.go b/go/example/claude-code/main.go index 296028e..b0bb3e2 100644 --- a/go/example/claude-code/main.go +++ b/go/example/claude-code/main.go @@ -69,10 +69,8 @@ func (c *replClient) SessionUpdate(ctx context.Context, params acp.SessionNotifi switch { case u.AgentMessageChunk != nil: content := u.AgentMessageChunk.Content - if content.Type == "text" && content.Text != nil { + if content.Text != nil { fmt.Printf("[agent] \n%s\n", content.Text.Text) - } else { - fmt.Printf("[agent] %s\n", content.Type) } case u.ToolCall != nil: fmt.Printf("\n🔧 %s (%s)\n", u.ToolCall.Title, u.ToolCall.Status) @@ -82,10 +80,8 @@ func (c *replClient) SessionUpdate(ctx context.Context, params acp.SessionNotifi fmt.Println("[plan update]") case u.AgentThoughtChunk != nil: thought := u.AgentThoughtChunk.Content - if thought.Type == "text" && thought.Text != nil { + if thought.Text != nil { fmt.Printf("[agent_thought_chunk] \n%s\n", thought.Text.Text) - } else { - fmt.Println("[agent_thought_chunk]", "(", thought.Type, ")") } case u.UserMessageChunk != nil: fmt.Println("[user_message_chunk]") diff --git a/go/example/client/main.go b/go/example/client/main.go index 6c6880d..eb59777 100644 --- a/go/example/client/main.go +++ b/go/example/client/main.go @@ -53,10 +53,8 @@ func (e *exampleClient) SessionUpdate(ctx context.Context, params acp.SessionNot switch { case u.AgentMessageChunk != nil: c := u.AgentMessageChunk.Content - if c.Type == "text" && c.Text != nil { + if c.Text != nil { fmt.Println(c.Text.Text) - } else { - fmt.Printf("[%s]\n", c.Type) } case u.ToolCall != nil: fmt.Printf("\n🔧 %s (%s)\n", u.ToolCall.Title, u.ToolCall.Status) diff --git a/go/example/gemini/main.go b/go/example/gemini/main.go index 70c34f5..64154d1 100644 --- a/go/example/gemini/main.go +++ b/go/example/gemini/main.go @@ -72,10 +72,8 @@ func (c *replClient) SessionUpdate(ctx context.Context, params acp.SessionNotifi switch { case u.AgentMessageChunk != nil: content := u.AgentMessageChunk.Content - if content.Type == "text" && content.Text != nil { + if content.Text != nil { fmt.Printf("%s", content.Text.Text) - } else { - fmt.Printf("[agent] %s\n", content.Type) } case u.ToolCall != nil: fmt.Printf("\n🔧 %s (%s)\n", u.ToolCall.Title, u.ToolCall.Status) @@ -85,10 +83,8 @@ func (c *replClient) SessionUpdate(ctx context.Context, params acp.SessionNotifi fmt.Println("[plan update]") case u.AgentThoughtChunk != nil: thought := u.AgentThoughtChunk.Content - if thought.Type == "text" && thought.Text != nil { + if thought.Text != nil { fmt.Printf("[agent_thought_chunk] \n%s\n", thought.Text.Text) - } else { - fmt.Println("[agent_thought_chunk]", "(", thought.Type, ")") } case u.UserMessageChunk != nil: fmt.Println("[user_message_chunk]") diff --git a/go/example_agent_test.go b/go/example_agent_test.go index 0e68180..5f85755 100644 --- a/go/example_agent_test.go +++ b/go/example_agent_test.go @@ -28,24 +28,20 @@ func (a *agentExample) Prompt(ctx context.Context, p PromptRequest) (PromptRespo // Stream an initial agent message. _ = a.conn.SessionUpdate(ctx, SessionNotification{ SessionId: p.SessionId, - Update: SessionUpdate{ - AgentMessageChunk: &SessionUpdateAgentMessageChunk{ - Content: TextBlock("I'll help you with that."), - }, - }, + Update: UpdateAgentMessageText("I'll help you with that."), }) // Announce a tool call. _ = a.conn.SessionUpdate(ctx, SessionNotification{ SessionId: p.SessionId, - Update: SessionUpdate{ToolCall: &SessionUpdateToolCall{ - ToolCallId: ToolCallId("call_1"), - Title: "Modifying configuration", - Kind: ToolKindEdit, - Status: ToolCallStatusPending, - Locations: []ToolCallLocation{{Path: "/project/config.json"}}, - RawInput: map[string]any{"path": "/project/config.json"}, - }}, + Update: StartToolCall( + ToolCallId("call_1"), + "Modifying configuration", + WithStartKind(ToolKindEdit), + WithStartStatus(ToolCallStatusPending), + WithStartLocations([]ToolCallLocation{{Path: "/project/config.json"}}), + WithStartRawInput(map[string]any{"path": "/project/config.json"}), + ), }) // Ask the client for permission to proceed with the change. @@ -69,15 +65,15 @@ func (a *agentExample) Prompt(ctx context.Context, p PromptRequest) (PromptRespo // Mark tool call completed and stream a final message. _ = a.conn.SessionUpdate(ctx, SessionNotification{ SessionId: p.SessionId, - Update: SessionUpdate{ToolCallUpdate: &SessionUpdateToolCallUpdate{ - ToolCallId: ToolCallId("call_1"), - Status: ToolCallStatusCompleted, - RawOutput: map[string]any{"success": true}, - }}, + Update: UpdateToolCall( + ToolCallId("call_1"), + WithUpdateStatus(ToolCallStatusCompleted), + WithUpdateRawOutput(map[string]any{"success": true}), + ), }) _ = a.conn.SessionUpdate(ctx, SessionNotification{ SessionId: p.SessionId, - Update: SessionUpdate{AgentMessageChunk: &SessionUpdateAgentMessageChunk{Content: TextBlock("Done.")}}, + Update: UpdateAgentMessageText("Done."), }) } diff --git a/go/example_client_test.go b/go/example_client_test.go index 6d3f5ba..e90e70c 100644 --- a/go/example_client_test.go +++ b/go/example_client_test.go @@ -34,10 +34,8 @@ func (clientExample) SessionUpdate(ctx context.Context, n SessionNotification) e switch { case u.AgentMessageChunk != nil: c := u.AgentMessageChunk.Content - if c.Type == "text" && c.Text != nil { + if c.Text != nil { fmt.Print(c.Text.Text) - } else { - fmt.Println("[", c.Type, "]") } case u.ToolCall != nil: title := u.ToolCall.Title diff --git a/go/example_gemini_test.go b/go/example_gemini_test.go index 4a7d1c9..97f4ac4 100644 --- a/go/example_gemini_test.go +++ b/go/example_gemini_test.go @@ -20,7 +20,7 @@ func (geminiClient) RequestPermission(ctx context.Context, p RequestPermissionRe func (geminiClient) SessionUpdate(ctx context.Context, n SessionNotification) error { if n.Update.AgentMessageChunk != nil { c := n.Update.AgentMessageChunk.Content - if c.Type == "text" && c.Text != nil { + if c.Text != nil { fmt.Print(c.Text.Text) } } diff --git a/go/helpers_gen.go b/go/helpers_gen.go index 9a493a9..c99f1c6 100644 --- a/go/helpers_gen.go +++ b/go/helpers_gen.go @@ -4,52 +4,46 @@ package acp // TextBlock constructs a text content block. func TextBlock(text string) ContentBlock { - return ContentBlock{ - Text: &TextContent{Text: text}, + return ContentBlock{Text: &ContentBlockText{ + Text: text, Type: "text", - } + }} } // ImageBlock constructs an inline image content block with base64-encoded data. func ImageBlock(data string, mimeType string) ContentBlock { - return ContentBlock{ - Image: &ImageContent{ - Data: data, - MimeType: mimeType, - }, - Type: "image", - } + return ContentBlock{Image: &ContentBlockImage{ + Data: data, + MimeType: mimeType, + Type: "image", + }} } // AudioBlock constructs an inline audio content block with base64-encoded data. func AudioBlock(data string, mimeType string) ContentBlock { - return ContentBlock{ - Audio: &AudioContent{ - Data: data, - MimeType: mimeType, - }, - Type: "audio", - } + return ContentBlock{Audio: &ContentBlockAudio{ + Data: data, + MimeType: mimeType, + Type: "audio", + }} } // ResourceLinkBlock constructs a resource_link content block with a name and URI. func ResourceLinkBlock(name string, uri string) ContentBlock { - return ContentBlock{ - ResourceLink: &ResourceLinkContent{ - Name: name, - Uri: uri, - }, + return ContentBlock{ResourceLink: &ContentBlockResourceLink{ + Name: name, Type: "resource_link", - } + Uri: uri, + }} } // ResourceBlock wraps an embedded resource as a content block. func ResourceBlock(res EmbeddedResource) ContentBlock { var r EmbeddedResource = res - return ContentBlock{ - Resource: &r, + return ContentBlock{Resource: &ContentBlockResource{ + Resource: r.Resource, Type: "resource", - } + }} } // ToolContent wraps a content block as tool-call content. @@ -86,3 +80,313 @@ func ToolTerminalRef(terminalId string) ToolCallContent { func Ptr[T any](v T) *T { return &v } + +// UpdateUserMessage constructs a user_message_chunk update with the given content. +func UpdateUserMessage(content ContentBlock) SessionUpdate { + return SessionUpdate{UserMessageChunk: &SessionUpdateUserMessageChunk{Content: content}} +} + +// UpdateUserMessageText constructs a user_message_chunk update from text. +func UpdateUserMessageText(text string) SessionUpdate { + return UpdateUserMessage(TextBlock(text)) +} + +// UpdateAgentMessage constructs an agent_message_chunk update with the given content. +func UpdateAgentMessage(content ContentBlock) SessionUpdate { + return SessionUpdate{AgentMessageChunk: &SessionUpdateAgentMessageChunk{Content: content}} +} + +// UpdateAgentMessageText constructs an agent_message_chunk update from text. +func UpdateAgentMessageText(text string) SessionUpdate { + return UpdateAgentMessage(TextBlock(text)) +} + +// UpdateAgentThought constructs an agent_thought_chunk update with the given content. +func UpdateAgentThought(content ContentBlock) SessionUpdate { + return SessionUpdate{AgentThoughtChunk: &SessionUpdateAgentThoughtChunk{Content: content}} +} + +// UpdateAgentThoughtText constructs an agent_thought_chunk update from text. +func UpdateAgentThoughtText(text string) SessionUpdate { + return UpdateAgentThought(TextBlock(text)) +} + +// UpdatePlan constructs a plan update with the provided entries. +func UpdatePlan(entries ...PlanEntry) SessionUpdate { + return SessionUpdate{Plan: &SessionUpdatePlan{Entries: entries}} +} + +type ToolCallStartOpt func(tc *SessionUpdateToolCall) + +// StartToolCall constructs a tool_call update with required fields and applies optional modifiers. +func StartToolCall(id ToolCallId, title string, opts ...ToolCallStartOpt) SessionUpdate { + tc := SessionUpdateToolCall{ + Title: title, + ToolCallId: id, + } + for _, opt := range opts { + opt(&tc) + } + return SessionUpdate{ToolCall: &tc} +} + +// WithStartKind sets the kind for a tool_call start update. +func WithStartKind(k ToolKind) ToolCallStartOpt { + return func(tc *SessionUpdateToolCall) { + tc.Kind = k + } +} + +// WithStartStatus sets the status for a tool_call start update. +func WithStartStatus(s ToolCallStatus) ToolCallStartOpt { + return func(tc *SessionUpdateToolCall) { + tc.Status = s + } +} + +// WithStartContent sets the initial content for a tool_call start update. +func WithStartContent(c []ToolCallContent) ToolCallStartOpt { + return func(tc *SessionUpdateToolCall) { + tc.Content = c + } +} + +// WithStartLocations sets file locations and, if a single path is provided and rawInput is empty, mirrors it as rawInput.path. +func WithStartLocations(l []ToolCallLocation) ToolCallStartOpt { + return func(tc *SessionUpdateToolCall) { + tc.Locations = l + if len(l) == 1 && l[0].Path != "" { + if tc.RawInput == nil { + tc.RawInput = map[string]any{"path": l[0].Path} + } else { + m, ok := tc.RawInput.(map[string]any) + if ok { + if _, exists := m["path"]; !exists { + m["path"] = l[0].Path + } + } + } + } + } +} + +// WithStartRawInput sets rawInput for a tool_call start update. +func WithStartRawInput(v any) ToolCallStartOpt { + return func(tc *SessionUpdateToolCall) { + tc.RawInput = v + } +} + +// WithStartRawOutput sets rawOutput for a tool_call start update. +func WithStartRawOutput(v any) ToolCallStartOpt { + return func(tc *SessionUpdateToolCall) { + tc.RawOutput = v + } +} + +type ToolCallUpdateOpt func(tu *SessionUpdateToolCallUpdate) + +// UpdateToolCall constructs a tool_call_update with the given ID and applies optional modifiers. +func UpdateToolCall(id ToolCallId, opts ...ToolCallUpdateOpt) SessionUpdate { + tu := SessionUpdateToolCallUpdate{ToolCallId: id} + for _, opt := range opts { + opt(&tu) + } + return SessionUpdate{ToolCallUpdate: &tu} +} + +// WithUpdateTitle sets the title for a tool_call_update. +func WithUpdateTitle(t string) ToolCallUpdateOpt { + return func(tu *SessionUpdateToolCallUpdate) { + tu.Title = Ptr(t) + } +} + +// WithUpdateKind sets the kind for a tool_call_update. +func WithUpdateKind(k ToolKind) ToolCallUpdateOpt { + return func(tu *SessionUpdateToolCallUpdate) { + tu.Kind = Ptr(k) + } +} + +// WithUpdateStatus sets the status for a tool_call_update. +func WithUpdateStatus(s ToolCallStatus) ToolCallUpdateOpt { + return func(tu *SessionUpdateToolCallUpdate) { + tu.Status = Ptr(s) + } +} + +// WithUpdateContent replaces the content collection for a tool_call_update. +func WithUpdateContent(c []ToolCallContent) ToolCallUpdateOpt { + return func(tu *SessionUpdateToolCallUpdate) { + tu.Content = c + } +} + +// WithUpdateLocations replaces the locations collection for a tool_call_update. +func WithUpdateLocations(l []ToolCallLocation) ToolCallUpdateOpt { + return func(tu *SessionUpdateToolCallUpdate) { + tu.Locations = l + } +} + +// WithUpdateRawInput sets rawInput for a tool_call_update. +func WithUpdateRawInput(v any) ToolCallUpdateOpt { + return func(tu *SessionUpdateToolCallUpdate) { + tu.RawInput = v + } +} + +// WithUpdateRawOutput sets rawOutput for a tool_call_update. +func WithUpdateRawOutput(v any) ToolCallUpdateOpt { + return func(tu *SessionUpdateToolCallUpdate) { + tu.RawOutput = v + } +} + +// NewContentBlockText constructs a ContentBlock using the 'text' variant. +func NewContentBlockText(text string) ContentBlock { + return ContentBlock{Text: &ContentBlockText{ + Text: text, + Type: "text", + }} +} + +// NewContentBlockImage constructs a ContentBlock using the 'image' variant. +func NewContentBlockImage(data string, mimeType string) ContentBlock { + return ContentBlock{Image: &ContentBlockImage{ + Data: data, + MimeType: mimeType, + Type: "image", + }} +} + +// NewContentBlockAudio constructs a ContentBlock using the 'audio' variant. +func NewContentBlockAudio(data string, mimeType string) ContentBlock { + return ContentBlock{Audio: &ContentBlockAudio{ + Data: data, + MimeType: mimeType, + Type: "audio", + }} +} + +// NewContentBlockResourceLink constructs a ContentBlock using the 'resource_link' variant. +func NewContentBlockResourceLink(name string, uri string) ContentBlock { + return ContentBlock{ResourceLink: &ContentBlockResourceLink{ + Name: name, + Type: "resource_link", + Uri: uri, + }} +} + +// NewContentBlockResource constructs a ContentBlock using the 'resource' variant. +func NewContentBlockResource(resource EmbeddedResourceResource) ContentBlock { + return ContentBlock{Resource: &ContentBlockResource{ + Resource: resource, + Type: "resource", + }} +} + +// NewRequestPermissionOutcomeCancelled constructs a RequestPermissionOutcome using the 'cancelled' variant. +func NewRequestPermissionOutcomeCancelled() RequestPermissionOutcome { + return RequestPermissionOutcome{Cancelled: &RequestPermissionOutcomeCancelled{Outcome: "cancelled"}} +} + +// NewRequestPermissionOutcomeSelected constructs a RequestPermissionOutcome using the 'selected' variant. +func NewRequestPermissionOutcomeSelected(optionId PermissionOptionId) RequestPermissionOutcome { + return RequestPermissionOutcome{Selected: &RequestPermissionOutcomeSelected{ + OptionId: optionId, + Outcome: "selected", + }} +} + +// NewSessionUpdateUserMessageChunk constructs a SessionUpdate using the 'user_message_chunk' variant. +func NewSessionUpdateUserMessageChunk(content ContentBlock) SessionUpdate { + return SessionUpdate{UserMessageChunk: &SessionUpdateUserMessageChunk{ + Content: content, + SessionUpdate: "user_message_chunk", + }} +} + +// NewSessionUpdateAgentMessageChunk constructs a SessionUpdate using the 'agent_message_chunk' variant. +func NewSessionUpdateAgentMessageChunk(content ContentBlock) SessionUpdate { + return SessionUpdate{AgentMessageChunk: &SessionUpdateAgentMessageChunk{ + Content: content, + SessionUpdate: "agent_message_chunk", + }} +} + +// NewSessionUpdateAgentThoughtChunk constructs a SessionUpdate using the 'agent_thought_chunk' variant. +func NewSessionUpdateAgentThoughtChunk(content ContentBlock) SessionUpdate { + return SessionUpdate{AgentThoughtChunk: &SessionUpdateAgentThoughtChunk{ + Content: content, + SessionUpdate: "agent_thought_chunk", + }} +} + +// NewSessionUpdateToolCall constructs a SessionUpdate using the 'tool_call' variant. +func NewSessionUpdateToolCall(toolCallId ToolCallId, title string) SessionUpdate { + return SessionUpdate{ToolCall: &SessionUpdateToolCall{ + SessionUpdate: "tool_call", + Title: title, + ToolCallId: toolCallId, + }} +} + +// NewSessionUpdateToolCallUpdate constructs a SessionUpdate using the 'tool_call_update' variant. +func NewSessionUpdateToolCallUpdate(toolCallId ToolCallId) SessionUpdate { + return SessionUpdate{ToolCallUpdate: &SessionUpdateToolCallUpdate{ + SessionUpdate: "tool_call_update", + ToolCallId: toolCallId, + }} +} + +// NewSessionUpdatePlan constructs a SessionUpdate using the 'plan' variant. +func NewSessionUpdatePlan(entries []PlanEntry) SessionUpdate { + return SessionUpdate{Plan: &SessionUpdatePlan{ + Entries: entries, + SessionUpdate: "plan", + }} +} + +// NewToolCallContentContent constructs a ToolCallContent using the 'content' variant. +func NewToolCallContentContent(content ContentBlock) ToolCallContent { + return ToolCallContent{Content: &ToolCallContentContent{ + Content: content, + Type: "content", + }} +} + +// NewToolCallContentDiff constructs a ToolCallContent using the 'diff' variant. +func NewToolCallContentDiff(path string, newText string) ToolCallContent { + return ToolCallContent{Diff: &ToolCallContentDiff{ + NewText: newText, + Path: path, + Type: "diff", + }} +} + +// NewToolCallContentTerminal constructs a ToolCallContent using the 'terminal' variant. +func NewToolCallContentTerminal(terminalId string) ToolCallContent { + return ToolCallContent{Terminal: &ToolCallContentTerminal{ + TerminalId: terminalId, + Type: "terminal", + }} +} + +// StartReadToolCall constructs a 'tool_call' update for reading a file: kind=read, status=pending, locations=[{path}], rawInput={path}. +func StartReadToolCall(id ToolCallId, title string, path string, opts ...ToolCallStartOpt) SessionUpdate { + base := []ToolCallStartOpt{WithStartKind(ToolKindRead), WithStartStatus(ToolCallStatusPending), WithStartLocations([]ToolCallLocation{ToolCallLocation{Path: path}}), WithStartRawInput(map[string]any{"path": path})} + args := append(base, opts...) + return StartToolCall(id, title, args...) +} + +// StartEditToolCall constructs a 'tool_call' update for editing content: kind=edit, status=pending, locations=[{path}], rawInput={path, content}. +func StartEditToolCall(id ToolCallId, title string, path string, content any, opts ...ToolCallStartOpt) SessionUpdate { + base := []ToolCallStartOpt{WithStartKind(ToolKindEdit), WithStartStatus(ToolCallStatusPending), WithStartLocations([]ToolCallLocation{ToolCallLocation{Path: path}}), WithStartRawInput(map[string]any{ + "content": content, + "path": path, + })} + args := append(base, opts...) + return StartToolCall(id, title, args...) +} diff --git a/go/json_parity_test.go b/go/json_parity_test.go new file mode 100644 index 0000000..4dc48c8 --- /dev/null +++ b/go/json_parity_test.go @@ -0,0 +1,209 @@ +package acp + +import ( + "encoding/json" + "os" + "path/filepath" + "reflect" + "strings" + "testing" +) + +// normalize unmarshals both sides to generic values and compare structurally. +func equalJSON(a, b []byte) (bool, string, string) { + var va any + var vb any + if err := json.Unmarshal(a, &va); err != nil { + return false, string(a), string(b) + } + if err := json.Unmarshal(b, &vb); err != nil { + return false, string(a), string(b) + } + return reflect.DeepEqual(va, vb), string(a), string(b) +} + +func mustReadGolden(t *testing.T, name string) []byte { + t.Helper() + p := filepath.Join("testdata", "json_golden", name) + b, err := os.ReadFile(p) + if err != nil { + t.Fatalf("read golden %s: %v", p, err) + } + return b +} + +// Generic golden runner for a specific type T +func runGolden[T any](t *testing.T, build func() T) { + t.Helper() + // Use the current subtest name; expect pattern like "/". + name := t.Name() + base := name + if i := strings.LastIndex(base, "/"); i >= 0 { + base = base[i+1:] + } + want := mustReadGolden(t, base+".json") + // Marshal from constructed value and compare with golden JSON. + got, err := json.Marshal(build()) + if err != nil { + t.Fatalf("marshal %s: %v", base, err) + } + if ok, ga, gw := equalJSON(got, want); !ok { + t.Fatalf("%s marshal mismatch\n got: %s\nwant: %s", base, ga, gw) + } + // Unmarshal golden into type, then marshal again and compare. + var v T + if err := json.Unmarshal(want, &v); err != nil { + t.Fatalf("unmarshal %s: %v", base, err) + } + round, err := json.Marshal(v) + if err != nil { + t.Fatalf("re-marshal %s: %v", base, err) + } + if ok, ga, gw := equalJSON(round, want); !ok { + t.Fatalf("%s round-trip mismatch\n got: %s\nwant: %s", base, ga, gw) + } +} + +func TestJSONGolden_ContentBlocks(t *testing.T) { + t.Run("content_text", func(t *testing.T) { + runGolden(t, func() ContentBlock { return TextBlock("What's the weather like today?") }) + }) + t.Run("content_image", func(t *testing.T) { + runGolden(t, func() ContentBlock { return ImageBlock("iVBORw0KGgoAAAANSUhEUgAAAAEAAAAB...", "image/png") }) + }) + t.Run("content_audio", func(t *testing.T) { + runGolden(t, func() ContentBlock { return AudioBlock("UklGRiQAAABXQVZFZm10IBAAAAABAAEAQB8AAAB...", "audio/wav") }) + }) + t.Run("content_resource_text", func(t *testing.T) { + runGolden(t, func() ContentBlock { + res := EmbeddedResourceResource{TextResourceContents: &TextResourceContents{Uri: "file:///home/user/script.py", MimeType: Ptr("text/x-python"), Text: "def hello():\n print('Hello, world!')"}} + return ResourceBlock(EmbeddedResource{Resource: res}) + }) + }) + t.Run("content_resource_blob", func(t *testing.T) { + runGolden(t, func() ContentBlock { + res := EmbeddedResourceResource{BlobResourceContents: &BlobResourceContents{Uri: "file:///home/user/document.pdf", MimeType: Ptr("application/pdf"), Blob: ""}} + return ResourceBlock(EmbeddedResource{Resource: res}) + }) + }) + t.Run("content_resource_link", func(t *testing.T) { + runGolden(t, func() ContentBlock { + mt := "application/pdf" + sz := 1024000 + return ContentBlock{ResourceLink: &ContentBlockResourceLink{Type: "resource_link", Uri: "file:///home/user/document.pdf", Name: "document.pdf", MimeType: &mt, Size: &sz}} + }) + }) +} + +func TestJSONGolden_ToolCallContent(t *testing.T) { + t.Run("tool_content_content_text", func(t *testing.T) { + runGolden(t, func() ToolCallContent { return ToolContent(TextBlock("Analysis complete. Found 3 issues.")) }) + }) + t.Run("tool_content_diff", func(t *testing.T) { + runGolden(t, func() ToolCallContent { + old := "{\n \"debug\": false\n}" + return ToolDiffContent("/home/user/project/src/config.json", "{\n \"debug\": true\n}", old) + }) + }) +} + +func TestJSONGolden_RequestPermissionOutcome(t *testing.T) { + t.Run("permission_outcome_selected", func(t *testing.T) { + runGolden(t, func() RequestPermissionOutcome { + return RequestPermissionOutcome{Selected: &RequestPermissionOutcomeSelected{Outcome: "selected", OptionId: "allow-once"}} + }) + }) + t.Run("permission_outcome_cancelled", func(t *testing.T) { + runGolden(t, func() RequestPermissionOutcome { + return RequestPermissionOutcome{Cancelled: &RequestPermissionOutcomeCancelled{Outcome: "cancelled"}} + }) + }) +} + +func TestJSONGolden_SessionUpdates(t *testing.T) { + t.Run("session_update_user_message_chunk", func(t *testing.T) { + runGolden(t, func() SessionUpdate { + return SessionUpdate{UserMessageChunk: &SessionUpdateUserMessageChunk{Content: TextBlock("What's the capital of France?")}} + }) + }) + t.Run("session_update_agent_message_chunk", func(t *testing.T) { + runGolden(t, func() SessionUpdate { + return SessionUpdate{AgentMessageChunk: &SessionUpdateAgentMessageChunk{Content: TextBlock("The capital of France is Paris.")}} + }) + }) + t.Run("session_update_agent_thought_chunk", func(t *testing.T) { + runGolden(t, func() SessionUpdate { + return SessionUpdate{AgentThoughtChunk: &SessionUpdateAgentThoughtChunk{Content: TextBlock("Thinking about best approach...")}} + }) + }) + t.Run("session_update_plan", func(t *testing.T) { + runGolden(t, func() SessionUpdate { + return SessionUpdate{Plan: &SessionUpdatePlan{Entries: []PlanEntry{{Content: "Check for syntax errors", Priority: PlanEntryPriorityHigh, Status: PlanEntryStatusPending}, {Content: "Identify potential type issues", Priority: PlanEntryPriorityMedium, Status: PlanEntryStatusPending}}}} + }) + }) + t.Run("session_update_tool_call", func(t *testing.T) { + runGolden(t, func() SessionUpdate { + return SessionUpdate{ToolCall: &SessionUpdateToolCall{ToolCallId: "call_001", Title: "Reading configuration file", Kind: ToolKindRead, Status: ToolCallStatusPending}} + }) + }) + t.Run("session_update_tool_call_update_content", func(t *testing.T) { + runGolden(t, func() SessionUpdate { + return SessionUpdate{ToolCallUpdate: &SessionUpdateToolCallUpdate{ToolCallId: "call_001", Status: Ptr(ToolCallStatusInProgress), Content: []ToolCallContent{ToolContent(TextBlock("Found 3 configuration files..."))}}} + }) + }) +} + +func TestJSONGolden_MethodPayloads(t *testing.T) { + t.Run("initialize_request", func(t *testing.T) { + runGolden(t, func() InitializeRequest { + return InitializeRequest{ProtocolVersion: 1, ClientCapabilities: ClientCapabilities{Fs: FileSystemCapability{ReadTextFile: true, WriteTextFile: true}}} + }) + }) + t.Run("initialize_response", func(t *testing.T) { + runGolden(t, func() InitializeResponse { + return InitializeResponse{ProtocolVersion: 1, AgentCapabilities: AgentCapabilities{LoadSession: true, PromptCapabilities: PromptCapabilities{Image: true, Audio: true, EmbeddedContext: true}}, AuthMethods: []AuthMethod{}} + }) + }) + t.Run("new_session_request", func(t *testing.T) { + runGolden(t, func() NewSessionRequest { + return NewSessionRequest{Cwd: "/home/user/project", McpServers: []McpServer{{Name: "filesystem", Command: "/path/to/mcp-server", Args: []string{"--stdio"}, Env: []EnvVariable{}}}} + }) + }) + t.Run("new_session_response", func(t *testing.T) { + runGolden(t, func() NewSessionResponse { return NewSessionResponse{SessionId: "sess_abc123def456"} }) + }) + t.Run("prompt_request", func(t *testing.T) { + runGolden(t, func() PromptRequest { + return PromptRequest{SessionId: "sess_abc123def456", Prompt: []ContentBlock{TextBlock("Can you analyze this code for potential issues?"), ResourceBlock(EmbeddedResource{Resource: EmbeddedResourceResource{TextResourceContents: &TextResourceContents{Uri: "file:///home/user/project/main.py", MimeType: Ptr("text/x-python"), Text: "def process_data(items):\n for item in items:\n print(item)"}}})}} + }) + }) + t.Run("fs_read_text_file_request", func(t *testing.T) { + runGolden(t, func() ReadTextFileRequest { + line, limit := 10, 50 + return ReadTextFileRequest{SessionId: "sess_abc123def456", Path: "/home/user/project/src/main.py", Line: &line, Limit: &limit} + }) + }) + t.Run("fs_read_text_file_response", func(t *testing.T) { + runGolden(t, func() ReadTextFileResponse { + return ReadTextFileResponse{Content: "def hello_world():\n print('Hello, world!')\n"} + }) + }) + t.Run("fs_write_text_file_request", func(t *testing.T) { + runGolden(t, func() WriteTextFileRequest { + return WriteTextFileRequest{SessionId: "sess_abc123def456", Path: "/home/user/project/config.json", Content: "{\n \"debug\": true,\n \"version\": \"1.0.0\"\n}"} + }) + }) + t.Run("request_permission_request", func(t *testing.T) { + runGolden(t, func() RequestPermissionRequest { + return RequestPermissionRequest{SessionId: "sess_abc123def456", ToolCall: ToolCallUpdate{ToolCallId: "call_001"}, Options: []PermissionOption{{OptionId: "allow-once", Name: "Allow once", Kind: PermissionOptionKindAllowOnce}, {OptionId: "reject-once", Name: "Reject", Kind: PermissionOptionKindRejectOnce}}} + }) + }) + t.Run("request_permission_response_selected", func(t *testing.T) { + runGolden(t, func() RequestPermissionResponse { + return RequestPermissionResponse{Outcome: RequestPermissionOutcome{Selected: &RequestPermissionOutcomeSelected{Outcome: "selected", OptionId: "allow-once"}}} + }) + }) + t.Run("cancel_notification", func(t *testing.T) { + runGolden(t, func() CancelNotification { return CancelNotification{SessionId: "sess_abc123def456"} }) + }) +} diff --git a/go/testdata/json_golden/cancel_notification.json b/go/testdata/json_golden/cancel_notification.json new file mode 100644 index 0000000..f1b3079 --- /dev/null +++ b/go/testdata/json_golden/cancel_notification.json @@ -0,0 +1,4 @@ +{ + "sessionId": "sess_abc123def456" +} + diff --git a/go/testdata/json_golden/content_audio.json b/go/testdata/json_golden/content_audio.json new file mode 100644 index 0000000..6474c1d --- /dev/null +++ b/go/testdata/json_golden/content_audio.json @@ -0,0 +1,6 @@ +{ + "type": "audio", + "mimeType": "audio/wav", + "data": "UklGRiQAAABXQVZFZm10IBAAAAABAAEAQB8AAAB..." +} + diff --git a/go/testdata/json_golden/content_image.json b/go/testdata/json_golden/content_image.json new file mode 100644 index 0000000..87fd55c --- /dev/null +++ b/go/testdata/json_golden/content_image.json @@ -0,0 +1,6 @@ +{ + "type": "image", + "mimeType": "image/png", + "data": "iVBORw0KGgoAAAANSUhEUgAAAAEAAAAB..." +} + diff --git a/go/testdata/json_golden/content_resource_blob.json b/go/testdata/json_golden/content_resource_blob.json new file mode 100644 index 0000000..121eda3 --- /dev/null +++ b/go/testdata/json_golden/content_resource_blob.json @@ -0,0 +1,9 @@ +{ + "type": "resource", + "resource": { + "uri": "file:///home/user/document.pdf", + "mimeType": "application/pdf", + "blob": "" + } +} + diff --git a/go/testdata/json_golden/content_resource_link.json b/go/testdata/json_golden/content_resource_link.json new file mode 100644 index 0000000..82a0a30 --- /dev/null +++ b/go/testdata/json_golden/content_resource_link.json @@ -0,0 +1,8 @@ +{ + "type": "resource_link", + "uri": "file:///home/user/document.pdf", + "name": "document.pdf", + "mimeType": "application/pdf", + "size": 1024000 +} + diff --git a/go/testdata/json_golden/content_resource_text.json b/go/testdata/json_golden/content_resource_text.json new file mode 100644 index 0000000..80f40fa --- /dev/null +++ b/go/testdata/json_golden/content_resource_text.json @@ -0,0 +1,9 @@ +{ + "type": "resource", + "resource": { + "uri": "file:///home/user/script.py", + "mimeType": "text/x-python", + "text": "def hello():\n print('Hello, world!')" + } +} + diff --git a/go/testdata/json_golden/content_text.json b/go/testdata/json_golden/content_text.json new file mode 100644 index 0000000..eca6047 --- /dev/null +++ b/go/testdata/json_golden/content_text.json @@ -0,0 +1,5 @@ +{ + "type": "text", + "text": "What's the weather like today?" +} + diff --git a/go/testdata/json_golden/fs_read_text_file_request.json b/go/testdata/json_golden/fs_read_text_file_request.json new file mode 100644 index 0000000..ec25ffa --- /dev/null +++ b/go/testdata/json_golden/fs_read_text_file_request.json @@ -0,0 +1,7 @@ +{ + "sessionId": "sess_abc123def456", + "path": "/home/user/project/src/main.py", + "line": 10, + "limit": 50 +} + diff --git a/go/testdata/json_golden/fs_read_text_file_response.json b/go/testdata/json_golden/fs_read_text_file_response.json new file mode 100644 index 0000000..68c73a3 --- /dev/null +++ b/go/testdata/json_golden/fs_read_text_file_response.json @@ -0,0 +1,4 @@ +{ + "content": "def hello_world():\n print('Hello, world!')\n" +} + diff --git a/go/testdata/json_golden/fs_write_text_file_request.json b/go/testdata/json_golden/fs_write_text_file_request.json new file mode 100644 index 0000000..eb72e39 --- /dev/null +++ b/go/testdata/json_golden/fs_write_text_file_request.json @@ -0,0 +1,6 @@ +{ + "sessionId": "sess_abc123def456", + "path": "/home/user/project/config.json", + "content": "{\n \"debug\": true,\n \"version\": \"1.0.0\"\n}" +} + diff --git a/go/testdata/json_golden/initialize_request.json b/go/testdata/json_golden/initialize_request.json new file mode 100644 index 0000000..5f689c7 --- /dev/null +++ b/go/testdata/json_golden/initialize_request.json @@ -0,0 +1,10 @@ +{ + "protocolVersion": 1, + "clientCapabilities": { + "fs": { + "readTextFile": true, + "writeTextFile": true + } + } +} + diff --git a/go/testdata/json_golden/initialize_response.json b/go/testdata/json_golden/initialize_response.json new file mode 100644 index 0000000..499e8fa --- /dev/null +++ b/go/testdata/json_golden/initialize_response.json @@ -0,0 +1,13 @@ +{ + "protocolVersion": 1, + "agentCapabilities": { + "loadSession": true, + "promptCapabilities": { + "image": true, + "audio": true, + "embeddedContext": true + } + }, + "authMethods": [] +} + diff --git a/go/testdata/json_golden/new_session_request.json b/go/testdata/json_golden/new_session_request.json new file mode 100644 index 0000000..271df0f --- /dev/null +++ b/go/testdata/json_golden/new_session_request.json @@ -0,0 +1,12 @@ +{ + "cwd": "/home/user/project", + "mcpServers": [ + { + "name": "filesystem", + "command": "/path/to/mcp-server", + "args": ["--stdio"], + "env": [] + } + ] +} + diff --git a/go/testdata/json_golden/new_session_response.json b/go/testdata/json_golden/new_session_response.json new file mode 100644 index 0000000..f1b3079 --- /dev/null +++ b/go/testdata/json_golden/new_session_response.json @@ -0,0 +1,4 @@ +{ + "sessionId": "sess_abc123def456" +} + diff --git a/go/testdata/json_golden/permission_outcome_cancelled.json b/go/testdata/json_golden/permission_outcome_cancelled.json new file mode 100644 index 0000000..3fe3090 --- /dev/null +++ b/go/testdata/json_golden/permission_outcome_cancelled.json @@ -0,0 +1,4 @@ +{ + "outcome": "cancelled" +} + diff --git a/go/testdata/json_golden/permission_outcome_selected.json b/go/testdata/json_golden/permission_outcome_selected.json new file mode 100644 index 0000000..3c79f8f --- /dev/null +++ b/go/testdata/json_golden/permission_outcome_selected.json @@ -0,0 +1,5 @@ +{ + "outcome": "selected", + "optionId": "allow-once" +} + diff --git a/go/testdata/json_golden/prompt_request.json b/go/testdata/json_golden/prompt_request.json new file mode 100644 index 0000000..c4da8fb --- /dev/null +++ b/go/testdata/json_golden/prompt_request.json @@ -0,0 +1,18 @@ +{ + "sessionId": "sess_abc123def456", + "prompt": [ + { + "type": "text", + "text": "Can you analyze this code for potential issues?" + }, + { + "type": "resource", + "resource": { + "uri": "file:///home/user/project/main.py", + "mimeType": "text/x-python", + "text": "def process_data(items):\n for item in items:\n print(item)" + } + } + ] +} + diff --git a/go/testdata/json_golden/request_permission_request.json b/go/testdata/json_golden/request_permission_request.json new file mode 100644 index 0000000..f845f99 --- /dev/null +++ b/go/testdata/json_golden/request_permission_request.json @@ -0,0 +1,19 @@ +{ + "sessionId": "sess_abc123def456", + "toolCall": { + "toolCallId": "call_001" + }, + "options": [ + { + "optionId": "allow-once", + "name": "Allow once", + "kind": "allow_once" + }, + { + "optionId": "reject-once", + "name": "Reject", + "kind": "reject_once" + } + ] +} + diff --git a/go/testdata/json_golden/request_permission_response_selected.json b/go/testdata/json_golden/request_permission_response_selected.json new file mode 100644 index 0000000..98df2fc --- /dev/null +++ b/go/testdata/json_golden/request_permission_response_selected.json @@ -0,0 +1,7 @@ +{ + "outcome": { + "outcome": "selected", + "optionId": "allow-once" + } +} + diff --git a/go/testdata/json_golden/session_update_agent_message_chunk.json b/go/testdata/json_golden/session_update_agent_message_chunk.json new file mode 100644 index 0000000..9b9f6d0 --- /dev/null +++ b/go/testdata/json_golden/session_update_agent_message_chunk.json @@ -0,0 +1,8 @@ +{ + "sessionUpdate": "agent_message_chunk", + "content": { + "type": "text", + "text": "The capital of France is Paris." + } +} + diff --git a/go/testdata/json_golden/session_update_agent_thought_chunk.json b/go/testdata/json_golden/session_update_agent_thought_chunk.json new file mode 100644 index 0000000..b331119 --- /dev/null +++ b/go/testdata/json_golden/session_update_agent_thought_chunk.json @@ -0,0 +1,8 @@ +{ + "sessionUpdate": "agent_thought_chunk", + "content": { + "type": "text", + "text": "Thinking about best approach..." + } +} + diff --git a/go/testdata/json_golden/session_update_plan.json b/go/testdata/json_golden/session_update_plan.json new file mode 100644 index 0000000..744b7cd --- /dev/null +++ b/go/testdata/json_golden/session_update_plan.json @@ -0,0 +1,16 @@ +{ + "sessionUpdate": "plan", + "entries": [ + { + "content": "Check for syntax errors", + "priority": "high", + "status": "pending" + }, + { + "content": "Identify potential type issues", + "priority": "medium", + "status": "pending" + } + ] +} + diff --git a/go/testdata/json_golden/session_update_tool_call.json b/go/testdata/json_golden/session_update_tool_call.json new file mode 100644 index 0000000..0ad4ce4 --- /dev/null +++ b/go/testdata/json_golden/session_update_tool_call.json @@ -0,0 +1,8 @@ +{ + "sessionUpdate": "tool_call", + "toolCallId": "call_001", + "title": "Reading configuration file", + "kind": "read", + "status": "pending" +} + diff --git a/go/testdata/json_golden/session_update_tool_call_update_content.json b/go/testdata/json_golden/session_update_tool_call_update_content.json new file mode 100644 index 0000000..6db93a5 --- /dev/null +++ b/go/testdata/json_golden/session_update_tool_call_update_content.json @@ -0,0 +1,15 @@ +{ + "sessionUpdate": "tool_call_update", + "toolCallId": "call_001", + "status": "in_progress", + "content": [ + { + "type": "content", + "content": { + "type": "text", + "text": "Found 3 configuration files..." + } + } + ] +} + diff --git a/go/testdata/json_golden/session_update_user_message_chunk.json b/go/testdata/json_golden/session_update_user_message_chunk.json new file mode 100644 index 0000000..7944b98 --- /dev/null +++ b/go/testdata/json_golden/session_update_user_message_chunk.json @@ -0,0 +1,8 @@ +{ + "sessionUpdate": "user_message_chunk", + "content": { + "type": "text", + "text": "What's the capital of France?" + } +} + diff --git a/go/testdata/json_golden/tool_content_content_text.json b/go/testdata/json_golden/tool_content_content_text.json new file mode 100644 index 0000000..820a413 --- /dev/null +++ b/go/testdata/json_golden/tool_content_content_text.json @@ -0,0 +1,8 @@ +{ + "type": "content", + "content": { + "type": "text", + "text": "Analysis complete. Found 3 issues." + } +} + diff --git a/go/testdata/json_golden/tool_content_diff.json b/go/testdata/json_golden/tool_content_diff.json new file mode 100644 index 0000000..c1755dc --- /dev/null +++ b/go/testdata/json_golden/tool_content_diff.json @@ -0,0 +1,7 @@ +{ + "type": "diff", + "path": "/home/user/project/src/config.json", + "oldText": "{\n \"debug\": false\n}", + "newText": "{\n \"debug\": true\n}" +} + diff --git a/go/types_gen.go b/go/types_gen.go index 0579924..8ede4b8 100644 --- a/go/types_gen.go +++ b/go/types_gen.go @@ -5,6 +5,7 @@ package acp import ( "context" "encoding/json" + "errors" "fmt" ) @@ -17,16 +18,326 @@ type AgentCapabilities struct { } // All possible notifications that an agent can send to a client. This enum is used internally for routing RPC notifications. You typically won't need to use this directly - use the notification methods on the ['Client'] trait instead. Notifications do not expect a response. -// AgentNotification is a union or complex schema; represented generically. -type AgentNotification any +type AgentNotification struct { + SessionNotification *SessionNotification `json:"-"` +} + +func (u *AgentNotification) UnmarshalJSON(b []byte) error { + var m map[string]json.RawMessage + if err := json.Unmarshal(b, &m); err != nil { + return err + } + { + var v SessionNotification + if json.Unmarshal(b, &v) == nil { + u.SessionNotification = &v + return nil + } + } + return nil +} +func (u AgentNotification) MarshalJSON() ([]byte, error) { + if u.SessionNotification != nil { + var m map[string]any + _b, _e := json.Marshal(*u.SessionNotification) + if _e != nil { + return []byte{}, _e + } + if json.Unmarshal(_b, &m) != nil { + return []byte{}, errors.New("invalid variant payload") + } + return json.Marshal(m) + } + return []byte{}, nil +} // All possible requests that an agent can send to a client. This enum is used internally for routing RPC requests. You typically won't need to use this directly - instead, use the methods on the ['Client'] trait. This enum encompasses all method calls from agent to client. -// AgentRequest is a union or complex schema; represented generically. -type AgentRequest any +type AgentRequest struct { + WriteTextFileRequest *WriteTextFileRequest `json:"-"` + ReadTextFileRequest *ReadTextFileRequest `json:"-"` + RequestPermissionRequest *RequestPermissionRequest `json:"-"` + CreateTerminalRequest *CreateTerminalRequest `json:"-"` + TerminalOutputRequest *TerminalOutputRequest `json:"-"` + ReleaseTerminalRequest *ReleaseTerminalRequest `json:"-"` + WaitForTerminalExitRequest *WaitForTerminalExitRequest `json:"-"` + KillTerminalRequest *KillTerminalRequest `json:"-"` +} + +func (u *AgentRequest) UnmarshalJSON(b []byte) error { + var m map[string]json.RawMessage + if err := json.Unmarshal(b, &m); err != nil { + return err + } + { + var v WriteTextFileRequest + if json.Unmarshal(b, &v) == nil { + u.WriteTextFileRequest = &v + return nil + } + } + { + var v ReadTextFileRequest + if json.Unmarshal(b, &v) == nil { + u.ReadTextFileRequest = &v + return nil + } + } + { + var v RequestPermissionRequest + if json.Unmarshal(b, &v) == nil { + u.RequestPermissionRequest = &v + return nil + } + } + { + var v CreateTerminalRequest + if json.Unmarshal(b, &v) == nil { + u.CreateTerminalRequest = &v + return nil + } + } + { + var v TerminalOutputRequest + if json.Unmarshal(b, &v) == nil { + u.TerminalOutputRequest = &v + return nil + } + } + { + var v ReleaseTerminalRequest + if json.Unmarshal(b, &v) == nil { + u.ReleaseTerminalRequest = &v + return nil + } + } + { + var v WaitForTerminalExitRequest + if json.Unmarshal(b, &v) == nil { + u.WaitForTerminalExitRequest = &v + return nil + } + } + { + var v KillTerminalRequest + if json.Unmarshal(b, &v) == nil { + u.KillTerminalRequest = &v + return nil + } + } + return nil +} +func (u AgentRequest) MarshalJSON() ([]byte, error) { + if u.WriteTextFileRequest != nil { + var m map[string]any + _b, _e := json.Marshal(*u.WriteTextFileRequest) + if _e != nil { + return []byte{}, _e + } + if json.Unmarshal(_b, &m) != nil { + return []byte{}, errors.New("invalid variant payload") + } + return json.Marshal(m) + } + if u.ReadTextFileRequest != nil { + var m map[string]any + _b, _e := json.Marshal(*u.ReadTextFileRequest) + if _e != nil { + return []byte{}, _e + } + if json.Unmarshal(_b, &m) != nil { + return []byte{}, errors.New("invalid variant payload") + } + return json.Marshal(m) + } + if u.RequestPermissionRequest != nil { + var m map[string]any + _b, _e := json.Marshal(*u.RequestPermissionRequest) + if _e != nil { + return []byte{}, _e + } + if json.Unmarshal(_b, &m) != nil { + return []byte{}, errors.New("invalid variant payload") + } + return json.Marshal(m) + } + if u.CreateTerminalRequest != nil { + var m map[string]any + _b, _e := json.Marshal(*u.CreateTerminalRequest) + if _e != nil { + return []byte{}, _e + } + if json.Unmarshal(_b, &m) != nil { + return []byte{}, errors.New("invalid variant payload") + } + return json.Marshal(m) + } + if u.TerminalOutputRequest != nil { + var m map[string]any + _b, _e := json.Marshal(*u.TerminalOutputRequest) + if _e != nil { + return []byte{}, _e + } + if json.Unmarshal(_b, &m) != nil { + return []byte{}, errors.New("invalid variant payload") + } + return json.Marshal(m) + } + if u.ReleaseTerminalRequest != nil { + var m map[string]any + _b, _e := json.Marshal(*u.ReleaseTerminalRequest) + if _e != nil { + return []byte{}, _e + } + if json.Unmarshal(_b, &m) != nil { + return []byte{}, errors.New("invalid variant payload") + } + return json.Marshal(m) + } + if u.WaitForTerminalExitRequest != nil { + var m map[string]any + _b, _e := json.Marshal(*u.WaitForTerminalExitRequest) + if _e != nil { + return []byte{}, _e + } + if json.Unmarshal(_b, &m) != nil { + return []byte{}, errors.New("invalid variant payload") + } + return json.Marshal(m) + } + if u.KillTerminalRequest != nil { + var m map[string]any + _b, _e := json.Marshal(*u.KillTerminalRequest) + if _e != nil { + return []byte{}, _e + } + if json.Unmarshal(_b, &m) != nil { + return []byte{}, errors.New("invalid variant payload") + } + return json.Marshal(m) + } + return []byte{}, nil +} // All possible responses that an agent can send to a client. This enum is used internally for routing RPC responses. You typically won't need to use this directly - the responses are handled automatically by the connection. These are responses to the corresponding ClientRequest variants. -// AgentResponse is a union or complex schema; represented generically. -type AgentResponse any +type AuthenticateResponse struct{} + +type LoadSessionResponse struct{} + +type AgentResponse struct { + InitializeResponse *InitializeResponse `json:"-"` + AuthenticateResponse *AuthenticateResponse `json:"-"` + NewSessionResponse *NewSessionResponse `json:"-"` + LoadSessionResponse *LoadSessionResponse `json:"-"` + PromptResponse *PromptResponse `json:"-"` +} + +func (u *AgentResponse) UnmarshalJSON(b []byte) error { + if string(b) == "null" { + var v AuthenticateResponse + u.AuthenticateResponse = &v + return nil + } + var m map[string]json.RawMessage + if err := json.Unmarshal(b, &m); err != nil { + return err + } + { + var v InitializeResponse + if json.Unmarshal(b, &v) == nil { + u.InitializeResponse = &v + return nil + } + } + { + var v AuthenticateResponse + if json.Unmarshal(b, &v) == nil { + u.AuthenticateResponse = &v + return nil + } + } + { + var v NewSessionResponse + if json.Unmarshal(b, &v) == nil { + u.NewSessionResponse = &v + return nil + } + } + { + var v LoadSessionResponse + if json.Unmarshal(b, &v) == nil { + u.LoadSessionResponse = &v + return nil + } + } + { + var v PromptResponse + if json.Unmarshal(b, &v) == nil { + u.PromptResponse = &v + return nil + } + } + return nil +} +func (u AgentResponse) MarshalJSON() ([]byte, error) { + if u.InitializeResponse != nil { + var m map[string]any + _b, _e := json.Marshal(*u.InitializeResponse) + if _e != nil { + return []byte{}, _e + } + if json.Unmarshal(_b, &m) != nil { + return []byte{}, errors.New("invalid variant payload") + } + return json.Marshal(m) + } + if u.AuthenticateResponse != nil { + return json.Marshal(nil) + var m map[string]any + _b, _e := json.Marshal(*u.AuthenticateResponse) + if _e != nil { + return []byte{}, _e + } + if json.Unmarshal(_b, &m) != nil { + return []byte{}, errors.New("invalid variant payload") + } + return json.Marshal(m) + } + if u.NewSessionResponse != nil { + var m map[string]any + _b, _e := json.Marshal(*u.NewSessionResponse) + if _e != nil { + return []byte{}, _e + } + if json.Unmarshal(_b, &m) != nil { + return []byte{}, errors.New("invalid variant payload") + } + return json.Marshal(m) + } + if u.LoadSessionResponse != nil { + return json.Marshal(nil) + var m map[string]any + _b, _e := json.Marshal(*u.LoadSessionResponse) + if _e != nil { + return []byte{}, _e + } + if json.Unmarshal(_b, &m) != nil { + return []byte{}, errors.New("invalid variant payload") + } + return json.Marshal(m) + } + if u.PromptResponse != nil { + var m map[string]any + _b, _e := json.Marshal(*u.PromptResponse) + if _e != nil { + return []byte{}, _e + } + if json.Unmarshal(_b, &m) != nil { + return []byte{}, errors.New("invalid variant payload") + } + return json.Marshal(m) + } + return []byte{}, nil +} // Optional annotations for the client. The client can use annotations to inform how objects are used or displayed type Annotations struct { @@ -97,8 +408,8 @@ func (u *AvailableCommandInput) UnmarshalJSON(b []byte) error { match = false } if match { - if err := json.Unmarshal(b, &v); err != nil { - return err + if json.Unmarshal(b, &v) != nil { + return errors.New("invalid variant payload") } u.UnstructuredCommandInput = &v return nil @@ -106,7 +417,7 @@ func (u *AvailableCommandInput) UnmarshalJSON(b []byte) error { } { var v UnstructuredCommandInput - if err := json.Unmarshal(b, &v); err == nil { + if json.Unmarshal(b, &v) == nil { u.UnstructuredCommandInput = &v return nil } @@ -115,7 +426,15 @@ func (u *AvailableCommandInput) UnmarshalJSON(b []byte) error { } func (u AvailableCommandInput) MarshalJSON() ([]byte, error) { if u.UnstructuredCommandInput != nil { - return json.Marshal(*u.UnstructuredCommandInput) + var m map[string]any + _b, _e := json.Marshal(*u.UnstructuredCommandInput) + if _e != nil { + return []byte{}, _e + } + if json.Unmarshal(_b, &m) != nil { + return []byte{}, errors.New("invalid variant payload") + } + return json.Marshal(m) } return []byte{}, nil } @@ -146,151 +465,693 @@ type ClientCapabilities struct { } // All possible notifications that a client can send to an agent. This enum is used internally for routing RPC notifications. You typically won't need to use this directly - use the notification methods on the ['Agent'] trait instead. Notifications do not expect a response. -// ClientNotification is a union or complex schema; represented generically. -type ClientNotification any +type ClientNotification struct { + CancelNotification *CancelNotification `json:"-"` +} + +func (u *ClientNotification) UnmarshalJSON(b []byte) error { + var m map[string]json.RawMessage + if err := json.Unmarshal(b, &m); err != nil { + return err + } + { + var v CancelNotification + if json.Unmarshal(b, &v) == nil { + u.CancelNotification = &v + return nil + } + } + return nil +} +func (u ClientNotification) MarshalJSON() ([]byte, error) { + if u.CancelNotification != nil { + var m map[string]any + _b, _e := json.Marshal(*u.CancelNotification) + if _e != nil { + return []byte{}, _e + } + if json.Unmarshal(_b, &m) != nil { + return []byte{}, errors.New("invalid variant payload") + } + return json.Marshal(m) + } + return []byte{}, nil +} // All possible requests that a client can send to an agent. This enum is used internally for routing RPC requests. You typically won't need to use this directly - instead, use the methods on the ['Agent'] trait. This enum encompasses all method calls from client to agent. -// ClientRequest is a union or complex schema; represented generically. -type ClientRequest any +type ClientRequest struct { + InitializeRequest *InitializeRequest `json:"-"` + AuthenticateRequest *AuthenticateRequest `json:"-"` + NewSessionRequest *NewSessionRequest `json:"-"` + LoadSessionRequest *LoadSessionRequest `json:"-"` + PromptRequest *PromptRequest `json:"-"` +} + +func (u *ClientRequest) UnmarshalJSON(b []byte) error { + var m map[string]json.RawMessage + if err := json.Unmarshal(b, &m); err != nil { + return err + } + { + var v InitializeRequest + if json.Unmarshal(b, &v) == nil { + u.InitializeRequest = &v + return nil + } + } + { + var v AuthenticateRequest + if json.Unmarshal(b, &v) == nil { + u.AuthenticateRequest = &v + return nil + } + } + { + var v NewSessionRequest + if json.Unmarshal(b, &v) == nil { + u.NewSessionRequest = &v + return nil + } + } + { + var v LoadSessionRequest + if json.Unmarshal(b, &v) == nil { + u.LoadSessionRequest = &v + return nil + } + } + { + var v PromptRequest + if json.Unmarshal(b, &v) == nil { + u.PromptRequest = &v + return nil + } + } + return nil +} +func (u ClientRequest) MarshalJSON() ([]byte, error) { + if u.InitializeRequest != nil { + var m map[string]any + _b, _e := json.Marshal(*u.InitializeRequest) + if _e != nil { + return []byte{}, _e + } + if json.Unmarshal(_b, &m) != nil { + return []byte{}, errors.New("invalid variant payload") + } + return json.Marshal(m) + } + if u.AuthenticateRequest != nil { + var m map[string]any + _b, _e := json.Marshal(*u.AuthenticateRequest) + if _e != nil { + return []byte{}, _e + } + if json.Unmarshal(_b, &m) != nil { + return []byte{}, errors.New("invalid variant payload") + } + return json.Marshal(m) + } + if u.NewSessionRequest != nil { + var m map[string]any + _b, _e := json.Marshal(*u.NewSessionRequest) + if _e != nil { + return []byte{}, _e + } + if json.Unmarshal(_b, &m) != nil { + return []byte{}, errors.New("invalid variant payload") + } + return json.Marshal(m) + } + if u.LoadSessionRequest != nil { + var m map[string]any + _b, _e := json.Marshal(*u.LoadSessionRequest) + if _e != nil { + return []byte{}, _e + } + if json.Unmarshal(_b, &m) != nil { + return []byte{}, errors.New("invalid variant payload") + } + return json.Marshal(m) + } + if u.PromptRequest != nil { + var m map[string]any + _b, _e := json.Marshal(*u.PromptRequest) + if _e != nil { + return []byte{}, _e + } + if json.Unmarshal(_b, &m) != nil { + return []byte{}, errors.New("invalid variant payload") + } + return json.Marshal(m) + } + return []byte{}, nil +} // All possible responses that a client can send to an agent. This enum is used internally for routing RPC responses. You typically won't need to use this directly - the responses are handled automatically by the connection. These are responses to the corresponding AgentRequest variants. -// ClientResponse is a union or complex schema; represented generically. -type ClientResponse any +type WriteTextFileResponse struct{} -// Content blocks represent displayable information in the Agent Client Protocol. They provide a structured way to handle various types of user-facing content—whether it's text from language models, images for analysis, or embedded resources for context. Content blocks appear in: - User prompts sent via 'session/prompt' - Language model output streamed through 'session/update' notifications - Progress updates and results from tool calls This structure is compatible with the Model Context Protocol (MCP), enabling agents to seamlessly forward content from MCP tool outputs without transformation. See protocol docs: [Content](https://agentclientprotocol.com/protocol/content) -type ResourceLinkContent struct { - Annotations any `json:"annotations,omitempty"` - Description *string `json:"description,omitempty"` - MimeType *string `json:"mimeType,omitempty"` - Name string `json:"name"` - Size *int64 `json:"size,omitempty"` - Title *string `json:"title,omitempty"` - Uri string `json:"uri"` -} +type ReleaseTerminalResponse struct{} -type ContentBlock struct { - Type string `json:"type"` - Text *TextContent `json:"-"` - Image *ImageContent `json:"-"` - Audio *AudioContent `json:"-"` - ResourceLink *ResourceLinkContent `json:"-"` - Resource *EmbeddedResource `json:"-"` +type KillTerminalResponse struct{} + +type ClientResponse struct { + WriteTextFileResponse *WriteTextFileResponse `json:"-"` + ReadTextFileResponse *ReadTextFileResponse `json:"-"` + RequestPermissionResponse *RequestPermissionResponse `json:"-"` + CreateTerminalResponse *CreateTerminalResponse `json:"-"` + TerminalOutputResponse *TerminalOutputResponse `json:"-"` + ReleaseTerminalResponse *ReleaseTerminalResponse `json:"-"` + WaitForTerminalExitResponse *WaitForTerminalExitResponse `json:"-"` + KillTerminalResponse *KillTerminalResponse `json:"-"` } -func (c *ContentBlock) UnmarshalJSON(b []byte) error { - var probe struct { - Type string `json:"type"` +func (u *ClientResponse) UnmarshalJSON(b []byte) error { + if string(b) == "null" { + var v WriteTextFileResponse + u.WriteTextFileResponse = &v + return nil } - if err := json.Unmarshal(b, &probe); err != nil { + var m map[string]json.RawMessage + if err := json.Unmarshal(b, &m); err != nil { return err } - c.Type = probe.Type - switch probe.Type { - case "text": - var v TextContent - if err := json.Unmarshal(b, &v); err != nil { - return err - } - c.Text = &v - case "image": - var v ImageContent - if err := json.Unmarshal(b, &v); err != nil { - return err - } - c.Image = &v - case "audio": - var v AudioContent - if err := json.Unmarshal(b, &v); err != nil { - return err - } - c.Audio = &v - case "resource_link": - var v ResourceLinkContent - if err := json.Unmarshal(b, &v); err != nil { - return err - } - c.ResourceLink = &v - case "resource": - var v EmbeddedResource - if err := json.Unmarshal(b, &v); err != nil { - return err - } - c.Resource = &v + { + var v WriteTextFileResponse + if json.Unmarshal(b, &v) == nil { + u.WriteTextFileResponse = &v + return nil + } + } + { + var v ReadTextFileResponse + if json.Unmarshal(b, &v) == nil { + u.ReadTextFileResponse = &v + return nil + } + } + { + var v RequestPermissionResponse + if json.Unmarshal(b, &v) == nil { + u.RequestPermissionResponse = &v + return nil + } + } + { + var v CreateTerminalResponse + if json.Unmarshal(b, &v) == nil { + u.CreateTerminalResponse = &v + return nil + } + } + { + var v TerminalOutputResponse + if json.Unmarshal(b, &v) == nil { + u.TerminalOutputResponse = &v + return nil + } + } + { + var v ReleaseTerminalResponse + if json.Unmarshal(b, &v) == nil { + u.ReleaseTerminalResponse = &v + return nil + } + } + { + var v WaitForTerminalExitResponse + if json.Unmarshal(b, &v) == nil { + u.WaitForTerminalExitResponse = &v + return nil + } + } + { + var v KillTerminalResponse + if json.Unmarshal(b, &v) == nil { + u.KillTerminalResponse = &v + return nil + } } return nil } -func (c ContentBlock) MarshalJSON() ([]byte, error) { - switch c.Type { - case "text": - if c.Text != nil { - return json.Marshal(map[string]any{ - "text": c.Text.Text, - "type": "text", - }) - } - case "image": - if c.Image != nil { - return json.Marshal(map[string]any{ - "data": c.Image.Data, - "mimeType": c.Image.MimeType, - "type": "image", - "uri": c.Image.Uri, - }) - } - case "audio": - if c.Audio != nil { - return json.Marshal(map[string]any{ - "data": c.Audio.Data, - "mimeType": c.Audio.MimeType, - "type": "audio", - }) - } - case "resource_link": - if c.ResourceLink != nil { - return json.Marshal(map[string]any{ - "description": c.ResourceLink.Description, - "mimeType": c.ResourceLink.MimeType, - "name": c.ResourceLink.Name, - "size": c.ResourceLink.Size, - "title": c.ResourceLink.Title, - "type": "resource_link", - "uri": c.ResourceLink.Uri, - }) - } - case "resource": - if c.Resource != nil { - return json.Marshal(map[string]any{ - "resource": c.Resource.Resource, - "type": "resource", - }) +func (u ClientResponse) MarshalJSON() ([]byte, error) { + if u.WriteTextFileResponse != nil { + return json.Marshal(nil) + var m map[string]any + _b, _e := json.Marshal(*u.WriteTextFileResponse) + if _e != nil { + return []byte{}, _e + } + if json.Unmarshal(_b, &m) != nil { + return []byte{}, errors.New("invalid variant payload") + } + return json.Marshal(m) + } + if u.ReadTextFileResponse != nil { + var m map[string]any + _b, _e := json.Marshal(*u.ReadTextFileResponse) + if _e != nil { + return []byte{}, _e + } + if json.Unmarshal(_b, &m) != nil { + return []byte{}, errors.New("invalid variant payload") + } + return json.Marshal(m) + } + if u.RequestPermissionResponse != nil { + var m map[string]any + _b, _e := json.Marshal(*u.RequestPermissionResponse) + if _e != nil { + return []byte{}, _e + } + if json.Unmarshal(_b, &m) != nil { + return []byte{}, errors.New("invalid variant payload") + } + return json.Marshal(m) + } + if u.CreateTerminalResponse != nil { + var m map[string]any + _b, _e := json.Marshal(*u.CreateTerminalResponse) + if _e != nil { + return []byte{}, _e + } + if json.Unmarshal(_b, &m) != nil { + return []byte{}, errors.New("invalid variant payload") + } + return json.Marshal(m) + } + if u.TerminalOutputResponse != nil { + var m map[string]any + _b, _e := json.Marshal(*u.TerminalOutputResponse) + if _e != nil { + return []byte{}, _e } + if json.Unmarshal(_b, &m) != nil { + return []byte{}, errors.New("invalid variant payload") + } + return json.Marshal(m) + } + if u.ReleaseTerminalResponse != nil { + return json.Marshal(nil) + var m map[string]any + _b, _e := json.Marshal(*u.ReleaseTerminalResponse) + if _e != nil { + return []byte{}, _e + } + if json.Unmarshal(_b, &m) != nil { + return []byte{}, errors.New("invalid variant payload") + } + return json.Marshal(m) + } + if u.WaitForTerminalExitResponse != nil { + var m map[string]any + _b, _e := json.Marshal(*u.WaitForTerminalExitResponse) + if _e != nil { + return []byte{}, _e + } + if json.Unmarshal(_b, &m) != nil { + return []byte{}, errors.New("invalid variant payload") + } + return json.Marshal(m) + } + if u.KillTerminalResponse != nil { + return json.Marshal(nil) + var m map[string]any + _b, _e := json.Marshal(*u.KillTerminalResponse) + if _e != nil { + return []byte{}, _e + } + if json.Unmarshal(_b, &m) != nil { + return []byte{}, errors.New("invalid variant payload") + } + return json.Marshal(m) } return []byte{}, nil } -func (c *ContentBlock) Validate() error { - switch c.Type { - case "text": - if c.Text == nil { - return fmt.Errorf("contentblock.text missing") +// Content blocks represent displayable information in the Agent Client Protocol. They provide a structured way to handle various types of user-facing content—whether it's text from language models, images for analysis, or embedded resources for context. Content blocks appear in: - User prompts sent via 'session/prompt' - Language model output streamed through 'session/update' notifications - Progress updates and results from tool calls This structure is compatible with the Model Context Protocol (MCP), enabling agents to seamlessly forward content from MCP tool outputs without transformation. See protocol docs: [Content](https://agentclientprotocol.com/protocol/content) +// Plain text content All agents MUST support text content blocks in prompts. +type ContentBlockText struct { + Annotations *Annotations `json:"annotations,omitempty"` + Text string `json:"text"` + Type string `json:"type"` +} + +// Images for visual context or analysis. Requires the 'image' prompt capability when included in prompts. +type ContentBlockImage struct { + Annotations *Annotations `json:"annotations,omitempty"` + Data string `json:"data"` + MimeType string `json:"mimeType"` + Type string `json:"type"` + Uri *string `json:"uri,omitempty"` +} + +// Audio data for transcription or analysis. Requires the 'audio' prompt capability when included in prompts. +type ContentBlockAudio struct { + Annotations *Annotations `json:"annotations,omitempty"` + Data string `json:"data"` + MimeType string `json:"mimeType"` + Type string `json:"type"` +} + +// References to resources that the agent can access. All agents MUST support resource links in prompts. +type ContentBlockResourceLink struct { + Annotations *Annotations `json:"annotations,omitempty"` + Description *string `json:"description,omitempty"` + MimeType *string `json:"mimeType,omitempty"` + Name string `json:"name"` + Size *int `json:"size,omitempty"` + Title *string `json:"title,omitempty"` + Type string `json:"type"` + Uri string `json:"uri"` +} + +// Complete resource contents embedded directly in the message. Preferred for including context as it avoids extra round-trips. Requires the 'embeddedContext' prompt capability when included in prompts. +type ContentBlockResource struct { + Annotations *Annotations `json:"annotations,omitempty"` + Resource EmbeddedResourceResource `json:"resource"` + Type string `json:"type"` +} + +type ContentBlock struct { + Text *ContentBlockText `json:"-"` + Image *ContentBlockImage `json:"-"` + Audio *ContentBlockAudio `json:"-"` + ResourceLink *ContentBlockResourceLink `json:"-"` + Resource *ContentBlockResource `json:"-"` +} + +func (u *ContentBlock) UnmarshalJSON(b []byte) error { + var m map[string]json.RawMessage + if err := json.Unmarshal(b, &m); err != nil { + return err + } + { + var disc string + if v, ok := m["type"]; ok { + json.Unmarshal(v, &disc) + } + switch disc { + case "text": + var v ContentBlockText + if json.Unmarshal(b, &v) != nil { + return errors.New("invalid variant payload") + } + u.Text = &v + return nil + case "image": + var v ContentBlockImage + if json.Unmarshal(b, &v) != nil { + return errors.New("invalid variant payload") + } + u.Image = &v + return nil + case "audio": + var v ContentBlockAudio + if json.Unmarshal(b, &v) != nil { + return errors.New("invalid variant payload") + } + u.Audio = &v + return nil + case "resource_link": + var v ContentBlockResourceLink + if json.Unmarshal(b, &v) != nil { + return errors.New("invalid variant payload") + } + u.ResourceLink = &v + return nil + case "resource": + var v ContentBlockResource + if json.Unmarshal(b, &v) != nil { + return errors.New("invalid variant payload") + } + u.Resource = &v + return nil + } + } + { + var v ContentBlockText + var match bool = true + if _, ok := m["type"]; !ok { + match = false + } + if _, ok := m["text"]; !ok { + match = false + } + if match { + if json.Unmarshal(b, &v) != nil { + return errors.New("invalid variant payload") + } + u.Text = &v + return nil + } + } + { + var v ContentBlockImage + var match bool = true + if _, ok := m["type"]; !ok { + match = false + } + if _, ok := m["data"]; !ok { + match = false + } + if _, ok := m["mimeType"]; !ok { + match = false + } + if match { + if json.Unmarshal(b, &v) != nil { + return errors.New("invalid variant payload") + } + u.Image = &v + return nil + } + } + { + var v ContentBlockAudio + var match bool = true + if _, ok := m["type"]; !ok { + match = false + } + if _, ok := m["data"]; !ok { + match = false + } + if _, ok := m["mimeType"]; !ok { + match = false + } + if match { + if json.Unmarshal(b, &v) != nil { + return errors.New("invalid variant payload") + } + u.Audio = &v + return nil + } + } + { + var v ContentBlockResourceLink + var match bool = true + if _, ok := m["type"]; !ok { + match = false + } + if _, ok := m["name"]; !ok { + match = false + } + if _, ok := m["uri"]; !ok { + match = false + } + if match { + if json.Unmarshal(b, &v) != nil { + return errors.New("invalid variant payload") + } + u.ResourceLink = &v + return nil + } + } + { + var v ContentBlockResource + var match bool = true + if _, ok := m["type"]; !ok { + match = false + } + if _, ok := m["resource"]; !ok { + match = false + } + if match { + if json.Unmarshal(b, &v) != nil { + return errors.New("invalid variant payload") + } + u.Resource = &v + return nil + } + } + { + var v ContentBlockText + if json.Unmarshal(b, &v) == nil { + u.Text = &v + return nil + } + } + { + var v ContentBlockImage + if json.Unmarshal(b, &v) == nil { + u.Image = &v + return nil + } + } + { + var v ContentBlockAudio + if json.Unmarshal(b, &v) == nil { + u.Audio = &v + return nil + } + } + { + var v ContentBlockResourceLink + if json.Unmarshal(b, &v) == nil { + u.ResourceLink = &v + return nil + } + } + { + var v ContentBlockResource + if json.Unmarshal(b, &v) == nil { + u.Resource = &v + return nil } - case "image": - if c.Image == nil { - return fmt.Errorf("contentblock.image missing") + } + return nil +} +func (u ContentBlock) MarshalJSON() ([]byte, error) { + if u.Text != nil { + var m map[string]any + _b, _e := json.Marshal(*u.Text) + if _e != nil { + return []byte{}, _e + } + if json.Unmarshal(_b, &m) != nil { + return []byte{}, errors.New("invalid variant payload") + } + m["type"] = "text" + { + var nm map[string]any + nm = make(map[string]any) + nm["type"] = "text" + nm["text"] = m["text"] + return json.Marshal(nm) + } + } + if u.Image != nil { + var m map[string]any + _b, _e := json.Marshal(*u.Image) + if _e != nil { + return []byte{}, _e + } + if json.Unmarshal(_b, &m) != nil { + return []byte{}, errors.New("invalid variant payload") + } + m["type"] = "image" + { + var nm map[string]any + nm = make(map[string]any) + nm["type"] = "image" + nm["data"] = m["data"] + nm["mimeType"] = m["mimeType"] + if _v, _ok := m["uri"]; _ok { + nm["uri"] = _v + } + return json.Marshal(nm) + } + } + if u.Audio != nil { + var m map[string]any + _b, _e := json.Marshal(*u.Audio) + if _e != nil { + return []byte{}, _e + } + if json.Unmarshal(_b, &m) != nil { + return []byte{}, errors.New("invalid variant payload") + } + m["type"] = "audio" + { + var nm map[string]any + nm = make(map[string]any) + nm["type"] = "audio" + nm["data"] = m["data"] + nm["mimeType"] = m["mimeType"] + return json.Marshal(nm) + } + } + if u.ResourceLink != nil { + var m map[string]any + _b, _e := json.Marshal(*u.ResourceLink) + if _e != nil { + return []byte{}, _e + } + if json.Unmarshal(_b, &m) != nil { + return []byte{}, errors.New("invalid variant payload") + } + m["type"] = "resource_link" + { + var nm map[string]any + nm = make(map[string]any) + nm["type"] = "resource_link" + nm["name"] = m["name"] + nm["uri"] = m["uri"] + if v1, ok1 := m["description"]; ok1 { + nm["description"] = v1 + } + if v2, ok2 := m["mimeType"]; ok2 { + nm["mimeType"] = v2 + } + if v3, ok3 := m["size"]; ok3 { + nm["size"] = v3 + } + if v4, ok4 := m["title"]; ok4 { + nm["title"] = v4 + } + return json.Marshal(nm) } - case "audio": - if c.Audio == nil { - return fmt.Errorf("contentblock.audio missing") + } + if u.Resource != nil { + var m map[string]any + _b, _e := json.Marshal(*u.Resource) + if _e != nil { + return []byte{}, _e } - case "resource_link": - if c.ResourceLink == nil { - return fmt.Errorf("contentblock.resource_link missing") + if json.Unmarshal(_b, &m) != nil { + return []byte{}, errors.New("invalid variant payload") } - case "resource": - if c.Resource == nil { - return fmt.Errorf("contentblock.resource missing") + m["type"] = "resource" + { + var nm map[string]any + nm = make(map[string]any) + nm["type"] = "resource" + nm["resource"] = m["resource"] + return json.Marshal(nm) } } + return []byte{}, nil +} + +func (u *ContentBlock) Validate() error { + var count int + if u.Text != nil { + count++ + } + if u.Image != nil { + count++ + } + if u.Audio != nil { + count++ + } + if u.ResourceLink != nil { + count++ + } + if u.Resource != nil { + count++ + } + if count != 1 { + return errors.New("ContentBlock must have exactly one variant set") + } return nil } @@ -338,16 +1199,32 @@ func (u *EmbeddedResourceResource) UnmarshalJSON(b []byte) error { if err := json.Unmarshal(b, &m); err != nil { return err } + if _, ok := m["text"]; ok { + var v TextResourceContents + if json.Unmarshal(b, &v) != nil { + return errors.New("invalid variant payload") + } + u.TextResourceContents = &v + return nil + } + if _, ok := m["blob"]; ok { + var v BlobResourceContents + if json.Unmarshal(b, &v) != nil { + return errors.New("invalid variant payload") + } + u.BlobResourceContents = &v + return nil + } { var v TextResourceContents - if err := json.Unmarshal(b, &v); err == nil { + if json.Unmarshal(b, &v) == nil { u.TextResourceContents = &v return nil } } { var v BlobResourceContents - if err := json.Unmarshal(b, &v); err == nil { + if json.Unmarshal(b, &v) == nil { u.BlobResourceContents = &v return nil } @@ -356,10 +1233,26 @@ func (u *EmbeddedResourceResource) UnmarshalJSON(b []byte) error { } func (u EmbeddedResourceResource) MarshalJSON() ([]byte, error) { if u.TextResourceContents != nil { - return json.Marshal(*u.TextResourceContents) + var m map[string]any + _b, _e := json.Marshal(*u.TextResourceContents) + if _e != nil { + return []byte{}, _e + } + if json.Unmarshal(_b, &m) != nil { + return []byte{}, errors.New("invalid variant payload") + } + return json.Marshal(m) } if u.BlobResourceContents != nil { - return json.Marshal(*u.BlobResourceContents) + var m map[string]any + _b, _e := json.Marshal(*u.BlobResourceContents) + if _e != nil { + return []byte{}, _e + } + if json.Unmarshal(_b, &m) != nil { + return []byte{}, errors.New("invalid variant payload") + } + return json.Marshal(m) } return []byte{}, nil } @@ -405,7 +1298,7 @@ type InitializeResponse struct { // Capabilities supported by the agent. AgentCapabilities AgentCapabilities `json:"agentCapabilities,omitempty"` // Authentication methods supported by the agent. - AuthMethods []AuthMethod `json:"authMethods,omitempty"` + AuthMethods []AuthMethod `json:"authMethods"` // The protocol version the client specified if supported by the agent, or the latest protocol version supported by the agent. The client should disconnect, if it doesn't support this version. ProtocolVersion ProtocolVersion `json:"protocolVersion"` } @@ -649,30 +1542,37 @@ func (u *RequestPermissionOutcome) UnmarshalJSON(b []byte) error { if err := json.Unmarshal(b, &m); err != nil { return err } + { + var disc string + if v, ok := m["outcome"]; ok { + json.Unmarshal(v, &disc) + } + switch disc { + case "cancelled": + var v RequestPermissionOutcomeCancelled + if json.Unmarshal(b, &v) != nil { + return errors.New("invalid variant payload") + } + u.Cancelled = &v + return nil + case "selected": + var v RequestPermissionOutcomeSelected + if json.Unmarshal(b, &v) != nil { + return errors.New("invalid variant payload") + } + u.Selected = &v + return nil + } + } { var v RequestPermissionOutcomeCancelled var match bool = true if _, ok := m["outcome"]; !ok { match = false } - var raw json.RawMessage - var ok bool - raw, ok = m["outcome"] - if !ok { - match = false - } - if ok { - var tmp any - if err := json.Unmarshal(raw, &tmp); err != nil { - return err - } - if fmt.Sprint(tmp) != fmt.Sprint("cancelled") { - match = false - } - } if match { - if err := json.Unmarshal(b, &v); err != nil { - return err + if json.Unmarshal(b, &v) != nil { + return errors.New("invalid variant payload") } u.Cancelled = &v return nil @@ -687,24 +1587,9 @@ func (u *RequestPermissionOutcome) UnmarshalJSON(b []byte) error { if _, ok := m["optionId"]; !ok { match = false } - var raw json.RawMessage - var ok bool - raw, ok = m["outcome"] - if !ok { - match = false - } - if ok { - var tmp any - if err := json.Unmarshal(raw, &tmp); err != nil { - return err - } - if fmt.Sprint(tmp) != fmt.Sprint("selected") { - match = false - } - } if match { - if err := json.Unmarshal(b, &v); err != nil { - return err + if json.Unmarshal(b, &v) != nil { + return errors.New("invalid variant payload") } u.Selected = &v return nil @@ -712,14 +1597,14 @@ func (u *RequestPermissionOutcome) UnmarshalJSON(b []byte) error { } { var v RequestPermissionOutcomeCancelled - if err := json.Unmarshal(b, &v); err == nil { + if json.Unmarshal(b, &v) == nil { u.Cancelled = &v return nil } } { var v RequestPermissionOutcomeSelected - if err := json.Unmarshal(b, &v); err == nil { + if json.Unmarshal(b, &v) == nil { u.Selected = &v return nil } @@ -728,14 +1613,46 @@ func (u *RequestPermissionOutcome) UnmarshalJSON(b []byte) error { } func (u RequestPermissionOutcome) MarshalJSON() ([]byte, error) { if u.Cancelled != nil { - return json.Marshal(*u.Cancelled) + var m map[string]any + _b, _e := json.Marshal(*u.Cancelled) + if _e != nil { + return []byte{}, _e + } + if json.Unmarshal(_b, &m) != nil { + return []byte{}, errors.New("invalid variant payload") + } + m["outcome"] = "cancelled" + return json.Marshal(m) } if u.Selected != nil { - return json.Marshal(*u.Selected) + var m map[string]any + _b, _e := json.Marshal(*u.Selected) + if _e != nil { + return []byte{}, _e + } + if json.Unmarshal(_b, &m) != nil { + return []byte{}, errors.New("invalid variant payload") + } + m["outcome"] = "selected" + return json.Marshal(m) } return []byte{}, nil } +func (u *RequestPermissionOutcome) Validate() error { + var count int + if u.Cancelled != nil { + count++ + } + if u.Selected != nil { + count++ + } + if count != 1 { + return errors.New("RequestPermissionOutcome must have exactly one variant set") + } + return nil +} + // Request for user permission to execute a tool call. Sent when the agent needs authorization before performing a sensitive operation. See protocol docs: [Requesting Permission](https://agentclientprotocol.com/protocol/tool-calls#requesting-permission) type RequestPermissionRequest struct { // Available permission options for the user to choose from. @@ -798,37 +1715,71 @@ func (v *SessionNotification) Validate() error { } // Different types of updates that can be sent during session processing. These updates provide real-time feedback about the agent's progress. See protocol docs: [Agent Reports Output](https://agentclientprotocol.com/protocol/prompt-turn#3-agent-reports-output) +// A chunk of the user's message being streamed. type SessionUpdateUserMessageChunk struct { - Content ContentBlock `json:"content"` + Content ContentBlock `json:"content"` + SessionUpdate string `json:"sessionUpdate"` } + +// A chunk of the agent's response being streamed. type SessionUpdateAgentMessageChunk struct { - Content ContentBlock `json:"content"` + Content ContentBlock `json:"content"` + SessionUpdate string `json:"sessionUpdate"` } + +// A chunk of the agent's internal reasoning being streamed. type SessionUpdateAgentThoughtChunk struct { - Content ContentBlock `json:"content"` + Content ContentBlock `json:"content"` + SessionUpdate string `json:"sessionUpdate"` } + +// Notification that a new tool call has been initiated. type SessionUpdateToolCall struct { - Content []ToolCallContent `json:"content,omitempty"` - Kind ToolKind `json:"kind,omitempty"` - Locations []ToolCallLocation `json:"locations,omitempty"` - RawInput any `json:"rawInput,omitempty"` - RawOutput any `json:"rawOutput,omitempty"` - Status ToolCallStatus `json:"status,omitempty"` - Title string `json:"title"` - ToolCallId ToolCallId `json:"toolCallId"` + // Content produced by the tool call. + Content []ToolCallContent `json:"content,omitempty"` + // The category of tool being invoked. Helps clients choose appropriate icons and UI treatment. + Kind ToolKind `json:"kind,omitempty"` + // File locations affected by this tool call. Enables "follow-along" features in clients. + Locations []ToolCallLocation `json:"locations,omitempty"` + // Raw input parameters sent to the tool. + RawInput any `json:"rawInput,omitempty"` + // Raw output returned by the tool. + RawOutput any `json:"rawOutput,omitempty"` + SessionUpdate string `json:"sessionUpdate"` + // Current execution status of the tool call. + Status ToolCallStatus `json:"status,omitempty"` + // Human-readable title describing what the tool is doing. + Title string `json:"title"` + // Unique identifier for this tool call within the session. + ToolCallId ToolCallId `json:"toolCallId"` } + +// Update on the status or results of a tool call. type SessionUpdateToolCallUpdate struct { - Content []ToolCallContent `json:"content,omitempty"` - Kind any `json:"kind,omitempty"` - Locations []ToolCallLocation `json:"locations,omitempty"` - RawInput any `json:"rawInput,omitempty"` - RawOutput any `json:"rawOutput,omitempty"` - Status any `json:"status,omitempty"` - Title *string `json:"title,omitempty"` - ToolCallId ToolCallId `json:"toolCallId"` + // Replace the content collection. + Content []ToolCallContent `json:"content,omitempty"` + // Update the tool kind. + Kind *ToolKind `json:"kind,omitempty"` + // Replace the locations collection. + Locations []ToolCallLocation `json:"locations,omitempty"` + // Update the raw input. + RawInput any `json:"rawInput,omitempty"` + // Update the raw output. + RawOutput any `json:"rawOutput,omitempty"` + SessionUpdate string `json:"sessionUpdate"` + // Update the execution status. + Status *ToolCallStatus `json:"status,omitempty"` + // Update the human-readable title. + Title *string `json:"title,omitempty"` + // The ID of the tool call being updated. + ToolCallId ToolCallId `json:"toolCallId"` } + +// The agent's execution plan for complex tasks. See protocol docs: [Agent Plan](https://agentclientprotocol.com/protocol/agent-plan) type SessionUpdatePlan struct { - Entries []PlanEntry `json:"entries"` + // The list of tasks to be accomplished. When updating a plan, the agent must send a complete list of all entries with their current status. The client replaces the entire plan with each update. + Entries []PlanEntry `json:"entries"` + SessionUpdate string `json:"sessionUpdate"` } type SessionUpdate struct { @@ -840,137 +1791,308 @@ type SessionUpdate struct { Plan *SessionUpdatePlan `json:"-"` } -func (s *SessionUpdate) UnmarshalJSON(b []byte) error { +func (u *SessionUpdate) UnmarshalJSON(b []byte) error { var m map[string]json.RawMessage if err := json.Unmarshal(b, &m); err != nil { return err } - var kind string - if v, ok := m["sessionUpdate"]; ok { - json.Unmarshal(v, &kind) + { + var disc string + if v, ok := m["sessionUpdate"]; ok { + json.Unmarshal(v, &disc) + } + switch disc { + case "user_message_chunk": + var v SessionUpdateUserMessageChunk + if json.Unmarshal(b, &v) != nil { + return errors.New("invalid variant payload") + } + u.UserMessageChunk = &v + return nil + case "agent_message_chunk": + var v SessionUpdateAgentMessageChunk + if json.Unmarshal(b, &v) != nil { + return errors.New("invalid variant payload") + } + u.AgentMessageChunk = &v + return nil + case "agent_thought_chunk": + var v SessionUpdateAgentThoughtChunk + if json.Unmarshal(b, &v) != nil { + return errors.New("invalid variant payload") + } + u.AgentThoughtChunk = &v + return nil + case "tool_call": + var v SessionUpdateToolCall + if json.Unmarshal(b, &v) != nil { + return errors.New("invalid variant payload") + } + u.ToolCall = &v + return nil + case "tool_call_update": + var v SessionUpdateToolCallUpdate + if json.Unmarshal(b, &v) != nil { + return errors.New("invalid variant payload") + } + u.ToolCallUpdate = &v + return nil + case "plan": + var v SessionUpdatePlan + if json.Unmarshal(b, &v) != nil { + return errors.New("invalid variant payload") + } + u.Plan = &v + return nil + } } - switch kind { - case "user_message_chunk": + { var v SessionUpdateUserMessageChunk - if err := json.Unmarshal(b, &v); err != nil { - return err + var match bool = true + if _, ok := m["sessionUpdate"]; !ok { + match = false } - s.UserMessageChunk = &v - return nil - case "agent_message_chunk": + if _, ok := m["content"]; !ok { + match = false + } + if match { + if json.Unmarshal(b, &v) != nil { + return errors.New("invalid variant payload") + } + u.UserMessageChunk = &v + return nil + } + } + { var v SessionUpdateAgentMessageChunk - if err := json.Unmarshal(b, &v); err != nil { - return err + var match bool = true + if _, ok := m["sessionUpdate"]; !ok { + match = false } - s.AgentMessageChunk = &v - return nil - case "agent_thought_chunk": + if _, ok := m["content"]; !ok { + match = false + } + if match { + if json.Unmarshal(b, &v) != nil { + return errors.New("invalid variant payload") + } + u.AgentMessageChunk = &v + return nil + } + } + { var v SessionUpdateAgentThoughtChunk - if err := json.Unmarshal(b, &v); err != nil { - return err + var match bool = true + if _, ok := m["sessionUpdate"]; !ok { + match = false } - s.AgentThoughtChunk = &v - return nil - case "tool_call": + if _, ok := m["content"]; !ok { + match = false + } + if match { + if json.Unmarshal(b, &v) != nil { + return errors.New("invalid variant payload") + } + u.AgentThoughtChunk = &v + return nil + } + } + { var v SessionUpdateToolCall - if err := json.Unmarshal(b, &v); err != nil { - return err + var match bool = true + if _, ok := m["sessionUpdate"]; !ok { + match = false } - s.ToolCall = &v - return nil - case "tool_call_update": + if _, ok := m["toolCallId"]; !ok { + match = false + } + if _, ok := m["title"]; !ok { + match = false + } + if match { + if json.Unmarshal(b, &v) != nil { + return errors.New("invalid variant payload") + } + u.ToolCall = &v + return nil + } + } + { var v SessionUpdateToolCallUpdate - if err := json.Unmarshal(b, &v); err != nil { - return err + var match bool = true + if _, ok := m["sessionUpdate"]; !ok { + match = false } - s.ToolCallUpdate = &v - return nil - case "plan": + if _, ok := m["toolCallId"]; !ok { + match = false + } + if match { + if json.Unmarshal(b, &v) != nil { + return errors.New("invalid variant payload") + } + u.ToolCallUpdate = &v + return nil + } + } + { var v SessionUpdatePlan - if err := json.Unmarshal(b, &v); err != nil { - return err + var match bool = true + if _, ok := m["sessionUpdate"]; !ok { + match = false + } + if _, ok := m["entries"]; !ok { + match = false + } + if match { + if json.Unmarshal(b, &v) != nil { + return errors.New("invalid variant payload") + } + u.Plan = &v + return nil + } + } + { + var v SessionUpdateUserMessageChunk + if json.Unmarshal(b, &v) == nil { + u.UserMessageChunk = &v + return nil + } + } + { + var v SessionUpdateAgentMessageChunk + if json.Unmarshal(b, &v) == nil { + u.AgentMessageChunk = &v + return nil + } + } + { + var v SessionUpdateAgentThoughtChunk + if json.Unmarshal(b, &v) == nil { + u.AgentThoughtChunk = &v + return nil + } + } + { + var v SessionUpdateToolCall + if json.Unmarshal(b, &v) == nil { + u.ToolCall = &v + return nil + } + } + { + var v SessionUpdateToolCallUpdate + if json.Unmarshal(b, &v) == nil { + u.ToolCallUpdate = &v + return nil + } + } + { + var v SessionUpdatePlan + if json.Unmarshal(b, &v) == nil { + u.Plan = &v + return nil } - s.Plan = &v - return nil } return nil } -func (s SessionUpdate) MarshalJSON() ([]byte, error) { - if s.UserMessageChunk != nil { - return json.Marshal(map[string]any{ - "content": s.UserMessageChunk.Content, - "sessionUpdate": "user_message_chunk", - }) - } - if s.AgentMessageChunk != nil { - return json.Marshal(map[string]any{ - "content": s.AgentMessageChunk.Content, - "sessionUpdate": "agent_message_chunk", - }) - } - if s.AgentThoughtChunk != nil { - return json.Marshal(map[string]any{ - "content": s.AgentThoughtChunk.Content, - "sessionUpdate": "agent_thought_chunk", - }) - } - if s.ToolCall != nil { - return json.Marshal(map[string]any{ - "content": s.ToolCall.Content, - "kind": s.ToolCall.Kind, - "locations": s.ToolCall.Locations, - "rawInput": s.ToolCall.RawInput, - "rawOutput": s.ToolCall.RawOutput, - "sessionUpdate": "tool_call", - "status": s.ToolCall.Status, - "title": s.ToolCall.Title, - "toolCallId": s.ToolCall.ToolCallId, - }) - } - if s.ToolCallUpdate != nil { - return json.Marshal(map[string]any{ - "content": s.ToolCallUpdate.Content, - "kind": s.ToolCallUpdate.Kind, - "locations": s.ToolCallUpdate.Locations, - "rawInput": s.ToolCallUpdate.RawInput, - "rawOutput": s.ToolCallUpdate.RawOutput, - "sessionUpdate": "tool_call_update", - "status": s.ToolCallUpdate.Status, - "title": s.ToolCallUpdate.Title, - "toolCallId": s.ToolCallUpdate.ToolCallId, - }) - } - if s.Plan != nil { - return json.Marshal(map[string]any{ - "entries": s.Plan.Entries, - "sessionUpdate": "plan", - }) +func (u SessionUpdate) MarshalJSON() ([]byte, error) { + if u.UserMessageChunk != nil { + var m map[string]any + _b, _e := json.Marshal(*u.UserMessageChunk) + if _e != nil { + return []byte{}, _e + } + if json.Unmarshal(_b, &m) != nil { + return []byte{}, errors.New("invalid variant payload") + } + m["sessionUpdate"] = "user_message_chunk" + return json.Marshal(m) + } + if u.AgentMessageChunk != nil { + var m map[string]any + _b, _e := json.Marshal(*u.AgentMessageChunk) + if _e != nil { + return []byte{}, _e + } + if json.Unmarshal(_b, &m) != nil { + return []byte{}, errors.New("invalid variant payload") + } + m["sessionUpdate"] = "agent_message_chunk" + return json.Marshal(m) + } + if u.AgentThoughtChunk != nil { + var m map[string]any + _b, _e := json.Marshal(*u.AgentThoughtChunk) + if _e != nil { + return []byte{}, _e + } + if json.Unmarshal(_b, &m) != nil { + return []byte{}, errors.New("invalid variant payload") + } + m["sessionUpdate"] = "agent_thought_chunk" + return json.Marshal(m) + } + if u.ToolCall != nil { + var m map[string]any + _b, _e := json.Marshal(*u.ToolCall) + if _e != nil { + return []byte{}, _e + } + if json.Unmarshal(_b, &m) != nil { + return []byte{}, errors.New("invalid variant payload") + } + m["sessionUpdate"] = "tool_call" + return json.Marshal(m) + } + if u.ToolCallUpdate != nil { + var m map[string]any + _b, _e := json.Marshal(*u.ToolCallUpdate) + if _e != nil { + return []byte{}, _e + } + if json.Unmarshal(_b, &m) != nil { + return []byte{}, errors.New("invalid variant payload") + } + m["sessionUpdate"] = "tool_call_update" + return json.Marshal(m) + } + if u.Plan != nil { + var m map[string]any + _b, _e := json.Marshal(*u.Plan) + if _e != nil { + return []byte{}, _e + } + if json.Unmarshal(_b, &m) != nil { + return []byte{}, errors.New("invalid variant payload") + } + m["sessionUpdate"] = "plan" + return json.Marshal(m) } return []byte{}, nil } -func (s *SessionUpdate) Validate() error { +func (u *SessionUpdate) Validate() error { var count int - if s.UserMessageChunk != nil { + if u.UserMessageChunk != nil { count++ } - if s.AgentMessageChunk != nil { + if u.AgentMessageChunk != nil { count++ } - if s.AgentThoughtChunk != nil { + if u.AgentThoughtChunk != nil { count++ } - if s.ToolCall != nil { + if u.ToolCall != nil { count++ } - if s.ToolCallUpdate != nil { + if u.ToolCallUpdate != nil { count++ } - if s.Plan != nil { + if u.Plan != nil { count++ } if count != 1 { - return fmt.Errorf("sessionupdate must have exactly one variant set") + return errors.New("SessionUpdate must have exactly one variant set") } return nil } @@ -1084,6 +2206,35 @@ func (u *ToolCallContent) UnmarshalJSON(b []byte) error { if err := json.Unmarshal(b, &m); err != nil { return err } + { + var disc string + if v, ok := m["type"]; ok { + json.Unmarshal(v, &disc) + } + switch disc { + case "content": + var v ToolCallContentContent + if json.Unmarshal(b, &v) != nil { + return errors.New("invalid variant payload") + } + u.Content = &v + return nil + case "diff": + var v ToolCallContentDiff + if json.Unmarshal(b, &v) != nil { + return errors.New("invalid variant payload") + } + u.Diff = &v + return nil + case "terminal": + var v ToolCallContentTerminal + if json.Unmarshal(b, &v) != nil { + return errors.New("invalid variant payload") + } + u.Terminal = &v + return nil + } + } { var v ToolCallContentContent var match bool = true @@ -1093,24 +2244,9 @@ func (u *ToolCallContent) UnmarshalJSON(b []byte) error { if _, ok := m["content"]; !ok { match = false } - var raw json.RawMessage - var ok bool - raw, ok = m["type"] - if !ok { - match = false - } - if ok { - var tmp any - if err := json.Unmarshal(raw, &tmp); err != nil { - return err - } - if fmt.Sprint(tmp) != fmt.Sprint("content") { - match = false - } - } if match { - if err := json.Unmarshal(b, &v); err != nil { - return err + if json.Unmarshal(b, &v) != nil { + return errors.New("invalid variant payload") } u.Content = &v return nil @@ -1128,24 +2264,9 @@ func (u *ToolCallContent) UnmarshalJSON(b []byte) error { if _, ok := m["newText"]; !ok { match = false } - var raw json.RawMessage - var ok bool - raw, ok = m["type"] - if !ok { - match = false - } - if ok { - var tmp any - if err := json.Unmarshal(raw, &tmp); err != nil { - return err - } - if fmt.Sprint(tmp) != fmt.Sprint("diff") { - match = false - } - } if match { - if err := json.Unmarshal(b, &v); err != nil { - return err + if json.Unmarshal(b, &v) != nil { + return errors.New("invalid variant payload") } u.Diff = &v return nil @@ -1160,24 +2281,9 @@ func (u *ToolCallContent) UnmarshalJSON(b []byte) error { if _, ok := m["terminalId"]; !ok { match = false } - var raw json.RawMessage - var ok bool - raw, ok = m["type"] - if !ok { - match = false - } - if ok { - var tmp any - if err := json.Unmarshal(raw, &tmp); err != nil { - return err - } - if fmt.Sprint(tmp) != fmt.Sprint("terminal") { - match = false - } - } if match { - if err := json.Unmarshal(b, &v); err != nil { - return err + if json.Unmarshal(b, &v) != nil { + return errors.New("invalid variant payload") } u.Terminal = &v return nil @@ -1185,21 +2291,21 @@ func (u *ToolCallContent) UnmarshalJSON(b []byte) error { } { var v ToolCallContentContent - if err := json.Unmarshal(b, &v); err == nil { + if json.Unmarshal(b, &v) == nil { u.Content = &v return nil } } { var v ToolCallContentDiff - if err := json.Unmarshal(b, &v); err == nil { + if json.Unmarshal(b, &v) == nil { u.Diff = &v return nil } } { var v ToolCallContentTerminal - if err := json.Unmarshal(b, &v); err == nil { + if json.Unmarshal(b, &v) == nil { u.Terminal = &v return nil } @@ -1208,17 +2314,61 @@ func (u *ToolCallContent) UnmarshalJSON(b []byte) error { } func (u ToolCallContent) MarshalJSON() ([]byte, error) { if u.Content != nil { - return json.Marshal(*u.Content) + var m map[string]any + _b, _e := json.Marshal(*u.Content) + if _e != nil { + return []byte{}, _e + } + if json.Unmarshal(_b, &m) != nil { + return []byte{}, errors.New("invalid variant payload") + } + m["type"] = "content" + return json.Marshal(m) } if u.Diff != nil { - return json.Marshal(*u.Diff) + var m map[string]any + _b, _e := json.Marshal(*u.Diff) + if _e != nil { + return []byte{}, _e + } + if json.Unmarshal(_b, &m) != nil { + return []byte{}, errors.New("invalid variant payload") + } + m["type"] = "diff" + return json.Marshal(m) } if u.Terminal != nil { - return json.Marshal(*u.Terminal) + var m map[string]any + _b, _e := json.Marshal(*u.Terminal) + if _e != nil { + return []byte{}, _e + } + if json.Unmarshal(_b, &m) != nil { + return []byte{}, errors.New("invalid variant payload") + } + m["type"] = "terminal" + return json.Marshal(m) } return []byte{}, nil } +func (u *ToolCallContent) Validate() error { + var count int + if u.Content != nil { + count++ + } + if u.Diff != nil { + count++ + } + if u.Terminal != nil { + count++ + } + if count != 1 { + return errors.New("ToolCallContent must have exactly one variant set") + } + return nil +} + // Unique identifier for a tool call within a session. type ToolCallId string diff --git a/package.json b/package.json index 8b260d9..a3e5263 100644 --- a/package.json +++ b/package.json @@ -36,7 +36,7 @@ "test:ts:watch": "vitest", "generate:json-schema": "cd rust && cargo run --bin generate --features unstable", "generate:ts-schema": "node typescript/generate.js", - "generate:go": "cd go/cmd/generate && go run . && cd ../.. && go fmt ./...", + "generate:go": "cd go/cmd/generate && env -u GOPATH -u GOMODCACHE go run . && cd ../.. && env -u GOPATH -u GOMODCACHE go run mvdan.cc/gofumpt@latest -w .", "generate": "npm run generate:json-schema && npm run generate:ts-schema && npm run generate:go && npm run format", "build": "npm run generate && tsc", "format": "prettier --write . && cargo fmt", From 0710c38ffaa1c74721fb6c8c48890d0dfe189c90 Mon Sep 17 00:00:00 2001 From: Thomas Kosiewski Date: Tue, 2 Sep 2025 12:42:50 +0200 Subject: [PATCH 13/22] test: enhance JSON parity tests with constructor function validation Change-Id: Idf60f2e210f012cb035aa4de5680a1fbcf84d4c5 Signed-off-by: Thomas Kosiewski --- go/json_parity_test.go | 107 ++++++++++++++++++++++++++++------------- 1 file changed, 73 insertions(+), 34 deletions(-) diff --git a/go/json_parity_test.go b/go/json_parity_test.go index 4dc48c8..aa238e5 100644 --- a/go/json_parity_test.go +++ b/go/json_parity_test.go @@ -32,8 +32,10 @@ func mustReadGolden(t *testing.T, name string) []byte { return b } -// Generic golden runner for a specific type T -func runGolden[T any](t *testing.T, build func() T) { +// Generic golden runner for a specific type T. Accepts one or more builders and +// asserts that all of them serialize to the same golden file derived from the +// subtest name. +func runGolden[T any](t *testing.T, builds ...func() T) { t.Helper() // Use the current subtest name; expect pattern like "/". name := t.Name() @@ -42,15 +44,17 @@ func runGolden[T any](t *testing.T, build func() T) { base = base[i+1:] } want := mustReadGolden(t, base+".json") - // Marshal from constructed value and compare with golden JSON. - got, err := json.Marshal(build()) - if err != nil { - t.Fatalf("marshal %s: %v", base, err) - } - if ok, ga, gw := equalJSON(got, want); !ok { - t.Fatalf("%s marshal mismatch\n got: %s\nwant: %s", base, ga, gw) + // Forward serialization for each builder matches the same golden JSON. + for _, build := range builds { + got, err := json.Marshal(build()) + if err != nil { + t.Fatalf("marshal %s: %v", base, err) + } + if ok, ga, gw := equalJSON(got, want); !ok { + t.Fatalf("%s marshal mismatch\n got: %s\nwant: %s", base, ga, gw) + } } - // Unmarshal golden into type, then marshal again and compare. + // Unmarshal golden into type, then marshal again and compare (one round-trip check). var v T if err := json.Unmarshal(want, &v); err != nil { t.Fatalf("unmarshal %s: %v", base, err) @@ -109,47 +113,82 @@ func TestJSONGolden_ToolCallContent(t *testing.T) { func TestJSONGolden_RequestPermissionOutcome(t *testing.T) { t.Run("permission_outcome_selected", func(t *testing.T) { - runGolden(t, func() RequestPermissionOutcome { - return RequestPermissionOutcome{Selected: &RequestPermissionOutcomeSelected{Outcome: "selected", OptionId: "allow-once"}} - }) + runGolden(t, + func() RequestPermissionOutcome { + return RequestPermissionOutcome{Selected: &RequestPermissionOutcomeSelected{Outcome: "selected", OptionId: "allow-once"}} + }, + func() RequestPermissionOutcome { + return NewRequestPermissionOutcomeSelected("allow-once") + }, + ) }) t.Run("permission_outcome_cancelled", func(t *testing.T) { - runGolden(t, func() RequestPermissionOutcome { - return RequestPermissionOutcome{Cancelled: &RequestPermissionOutcomeCancelled{Outcome: "cancelled"}} - }) + runGolden(t, + func() RequestPermissionOutcome { + return RequestPermissionOutcome{Cancelled: &RequestPermissionOutcomeCancelled{Outcome: "cancelled"}} + }, + func() RequestPermissionOutcome { return NewRequestPermissionOutcomeCancelled() }, + ) }) } func TestJSONGolden_SessionUpdates(t *testing.T) { t.Run("session_update_user_message_chunk", func(t *testing.T) { - runGolden(t, func() SessionUpdate { - return SessionUpdate{UserMessageChunk: &SessionUpdateUserMessageChunk{Content: TextBlock("What's the capital of France?")}} - }) + runGolden(t, + func() SessionUpdate { + return SessionUpdate{UserMessageChunk: &SessionUpdateUserMessageChunk{Content: TextBlock("What's the capital of France?")}} + }, + func() SessionUpdate { return UpdateUserMessageText("What's the capital of France?") }, + ) }) t.Run("session_update_agent_message_chunk", func(t *testing.T) { - runGolden(t, func() SessionUpdate { - return SessionUpdate{AgentMessageChunk: &SessionUpdateAgentMessageChunk{Content: TextBlock("The capital of France is Paris.")}} - }) + runGolden(t, + func() SessionUpdate { + return SessionUpdate{AgentMessageChunk: &SessionUpdateAgentMessageChunk{Content: TextBlock("The capital of France is Paris.")}} + }, + func() SessionUpdate { return UpdateAgentMessageText("The capital of France is Paris.") }, + ) }) t.Run("session_update_agent_thought_chunk", func(t *testing.T) { - runGolden(t, func() SessionUpdate { - return SessionUpdate{AgentThoughtChunk: &SessionUpdateAgentThoughtChunk{Content: TextBlock("Thinking about best approach...")}} - }) + runGolden(t, + func() SessionUpdate { + return SessionUpdate{AgentThoughtChunk: &SessionUpdateAgentThoughtChunk{Content: TextBlock("Thinking about best approach...")}} + }, + func() SessionUpdate { return UpdateAgentThoughtText("Thinking about best approach...") }, + ) }) t.Run("session_update_plan", func(t *testing.T) { - runGolden(t, func() SessionUpdate { - return SessionUpdate{Plan: &SessionUpdatePlan{Entries: []PlanEntry{{Content: "Check for syntax errors", Priority: PlanEntryPriorityHigh, Status: PlanEntryStatusPending}, {Content: "Identify potential type issues", Priority: PlanEntryPriorityMedium, Status: PlanEntryStatusPending}}}} - }) + runGolden(t, + func() SessionUpdate { + return SessionUpdate{Plan: &SessionUpdatePlan{Entries: []PlanEntry{{Content: "Check for syntax errors", Priority: PlanEntryPriorityHigh, Status: PlanEntryStatusPending}, {Content: "Identify potential type issues", Priority: PlanEntryPriorityMedium, Status: PlanEntryStatusPending}}}} + }, + func() SessionUpdate { + return UpdatePlan( + PlanEntry{Content: "Check for syntax errors", Priority: PlanEntryPriorityHigh, Status: PlanEntryStatusPending}, + PlanEntry{Content: "Identify potential type issues", Priority: PlanEntryPriorityMedium, Status: PlanEntryStatusPending}, + ) + }, + ) }) t.Run("session_update_tool_call", func(t *testing.T) { - runGolden(t, func() SessionUpdate { - return SessionUpdate{ToolCall: &SessionUpdateToolCall{ToolCallId: "call_001", Title: "Reading configuration file", Kind: ToolKindRead, Status: ToolCallStatusPending}} - }) + runGolden(t, + func() SessionUpdate { + return SessionUpdate{ToolCall: &SessionUpdateToolCall{ToolCallId: "call_001", Title: "Reading configuration file", Kind: ToolKindRead, Status: ToolCallStatusPending}} + }, + func() SessionUpdate { + return StartToolCall("call_001", "Reading configuration file", WithStartKind(ToolKindRead), WithStartStatus(ToolCallStatusPending)) + }, + ) }) t.Run("session_update_tool_call_update_content", func(t *testing.T) { - runGolden(t, func() SessionUpdate { - return SessionUpdate{ToolCallUpdate: &SessionUpdateToolCallUpdate{ToolCallId: "call_001", Status: Ptr(ToolCallStatusInProgress), Content: []ToolCallContent{ToolContent(TextBlock("Found 3 configuration files..."))}}} - }) + runGolden(t, + func() SessionUpdate { + return SessionUpdate{ToolCallUpdate: &SessionUpdateToolCallUpdate{ToolCallId: "call_001", Status: Ptr(ToolCallStatusInProgress), Content: []ToolCallContent{ToolContent(TextBlock("Found 3 configuration files..."))}}} + }, + func() SessionUpdate { + return UpdateToolCall("call_001", WithUpdateStatus(ToolCallStatusInProgress), WithUpdateContent([]ToolCallContent{ToolContent(TextBlock("Found 3 configuration files..."))})) + }, + ) }) } From cd12c485c8da8f9bdc59f3b7e87730042d81a1e3 Mon Sep 17 00:00:00 2001 From: Thomas Kosiewski Date: Tue, 2 Sep 2025 13:01:30 +0200 Subject: [PATCH 14/22] test: refactor JSON parity tests with enhanced constructor function validation Change-Id: I9363d69a0842e656ecbfeeba19f379206407d924 Signed-off-by: Thomas Kosiewski --- go/json_parity_test.go | 425 ++++++++++-------- .../session_update_tool_call_edit.json | 17 + ...n_update_tool_call_locations_rawinput.json | 14 + .../session_update_tool_call_minimal.json | 6 + .../session_update_tool_call_read.json | 16 + ...ssion_update_tool_call_update_minimal.json | 5 + ...n_update_tool_call_update_more_fields.json | 28 ++ .../json_golden/tool_content_diff_no_old.json | 6 + .../json_golden/tool_content_terminal.json | 5 + 9 files changed, 336 insertions(+), 186 deletions(-) create mode 100644 go/testdata/json_golden/session_update_tool_call_edit.json create mode 100644 go/testdata/json_golden/session_update_tool_call_locations_rawinput.json create mode 100644 go/testdata/json_golden/session_update_tool_call_minimal.json create mode 100644 go/testdata/json_golden/session_update_tool_call_read.json create mode 100644 go/testdata/json_golden/session_update_tool_call_update_minimal.json create mode 100644 go/testdata/json_golden/session_update_tool_call_update_more_fields.json create mode 100644 go/testdata/json_golden/tool_content_diff_no_old.json create mode 100644 go/testdata/json_golden/tool_content_terminal.json diff --git a/go/json_parity_test.go b/go/json_parity_test.go index aa238e5..b2d4590 100644 --- a/go/json_parity_test.go +++ b/go/json_parity_test.go @@ -33,216 +33,269 @@ func mustReadGolden(t *testing.T, name string) []byte { } // Generic golden runner for a specific type T. Accepts one or more builders and -// asserts that all of them serialize to the same golden file derived from the -// subtest name. -func runGolden[T any](t *testing.T, builds ...func() T) { - t.Helper() - // Use the current subtest name; expect pattern like "/". - name := t.Name() - base := name - if i := strings.LastIndex(base, "/"); i >= 0 { - base = base[i+1:] - } - want := mustReadGolden(t, base+".json") - // Forward serialization for each builder matches the same golden JSON. - for _, build := range builds { - got, err := json.Marshal(build()) +// returns a subtest function that asserts they all serialize to the same golden +// file derived from the subtest name. +func runGolden[T any](builds ...func() T) func(t *testing.T) { + return func(t *testing.T) { + t.Helper() + // Use the current subtest name; expect pattern like "/". + name := t.Name() + base := name + if i := strings.LastIndex(base, "/"); i >= 0 { + base = base[i+1:] + } + want := mustReadGolden(t, base+".json") + // Forward serialization for each builder matches the same golden JSON. + for _, build := range builds { + got, err := json.Marshal(build()) + if err != nil { + t.Fatalf("marshal %s: %v", base, err) + } + if ok, ga, gw := equalJSON(got, want); !ok { + t.Fatalf("%s marshal mismatch\n got: %s\nwant: %s", base, ga, gw) + } + } + // Unmarshal golden into type, then marshal again and compare (one round-trip check). + var v T + if err := json.Unmarshal(want, &v); err != nil { + t.Fatalf("unmarshal %s: %v", base, err) + } + round, err := json.Marshal(v) if err != nil { - t.Fatalf("marshal %s: %v", base, err) + t.Fatalf("re-marshal %s: %v", base, err) } - if ok, ga, gw := equalJSON(got, want); !ok { - t.Fatalf("%s marshal mismatch\n got: %s\nwant: %s", base, ga, gw) + if ok, ga, gw := equalJSON(round, want); !ok { + t.Fatalf("%s round-trip mismatch\n got: %s\nwant: %s", base, ga, gw) } } - // Unmarshal golden into type, then marshal again and compare (one round-trip check). - var v T - if err := json.Unmarshal(want, &v); err != nil { - t.Fatalf("unmarshal %s: %v", base, err) - } - round, err := json.Marshal(v) - if err != nil { - t.Fatalf("re-marshal %s: %v", base, err) - } - if ok, ga, gw := equalJSON(round, want); !ok { - t.Fatalf("%s round-trip mismatch\n got: %s\nwant: %s", base, ga, gw) - } } func TestJSONGolden_ContentBlocks(t *testing.T) { - t.Run("content_text", func(t *testing.T) { - runGolden(t, func() ContentBlock { return TextBlock("What's the weather like today?") }) - }) - t.Run("content_image", func(t *testing.T) { - runGolden(t, func() ContentBlock { return ImageBlock("iVBORw0KGgoAAAANSUhEUgAAAAEAAAAB...", "image/png") }) - }) - t.Run("content_audio", func(t *testing.T) { - runGolden(t, func() ContentBlock { return AudioBlock("UklGRiQAAABXQVZFZm10IBAAAAABAAEAQB8AAAB...", "audio/wav") }) - }) - t.Run("content_resource_text", func(t *testing.T) { - runGolden(t, func() ContentBlock { + t.Run("content_text", runGolden( + func() ContentBlock { return TextBlock("What's the weather like today?") }, + func() ContentBlock { return NewContentBlockText("What's the weather like today?") }, + )) + t.Run("content_image", runGolden( + func() ContentBlock { return ImageBlock("iVBORw0KGgoAAAANSUhEUgAAAAEAAAAB...", "image/png") }, + func() ContentBlock { return NewContentBlockImage("iVBORw0KGgoAAAANSUhEUgAAAAEAAAAB...", "image/png") }, + )) + t.Run("content_audio", runGolden( + func() ContentBlock { return AudioBlock("UklGRiQAAABXQVZFZm10IBAAAAABAAEAQB8AAAB...", "audio/wav") }, + func() ContentBlock { + return NewContentBlockAudio("UklGRiQAAABXQVZFZm10IBAAAAABAAEAQB8AAAB...", "audio/wav") + }, + )) + t.Run("content_resource_text", runGolden( + func() ContentBlock { res := EmbeddedResourceResource{TextResourceContents: &TextResourceContents{Uri: "file:///home/user/script.py", MimeType: Ptr("text/x-python"), Text: "def hello():\n print('Hello, world!')"}} return ResourceBlock(EmbeddedResource{Resource: res}) - }) - }) - t.Run("content_resource_blob", func(t *testing.T) { - runGolden(t, func() ContentBlock { + }, + func() ContentBlock { + res := EmbeddedResourceResource{TextResourceContents: &TextResourceContents{Uri: "file:///home/user/script.py", MimeType: Ptr("text/x-python"), Text: "def hello():\n print('Hello, world!')"}} + return NewContentBlockResource(res) + }, + )) + t.Run("content_resource_blob", runGolden( + func() ContentBlock { res := EmbeddedResourceResource{BlobResourceContents: &BlobResourceContents{Uri: "file:///home/user/document.pdf", MimeType: Ptr("application/pdf"), Blob: ""}} return ResourceBlock(EmbeddedResource{Resource: res}) - }) - }) - t.Run("content_resource_link", func(t *testing.T) { - runGolden(t, func() ContentBlock { + }, + func() ContentBlock { + res := EmbeddedResourceResource{BlobResourceContents: &BlobResourceContents{Uri: "file:///home/user/document.pdf", MimeType: Ptr("application/pdf"), Blob: ""}} + return NewContentBlockResource(res) + }, + )) + t.Run("content_resource_link", runGolden( + func() ContentBlock { mt := "application/pdf" sz := 1024000 return ContentBlock{ResourceLink: &ContentBlockResourceLink{Type: "resource_link", Uri: "file:///home/user/document.pdf", Name: "document.pdf", MimeType: &mt, Size: &sz}} - }) - }) + }, + func() ContentBlock { + cb := ResourceLinkBlock("document.pdf", "file:///home/user/document.pdf") + mt := "application/pdf" + sz := 1024000 + cb.ResourceLink.MimeType = &mt + cb.ResourceLink.Size = &sz + return cb + }, + func() ContentBlock { + cb := NewContentBlockResourceLink("document.pdf", "file:///home/user/document.pdf") + mt := "application/pdf" + sz := 1024000 + cb.ResourceLink.MimeType = &mt + cb.ResourceLink.Size = &sz + return cb + }, + )) } func TestJSONGolden_ToolCallContent(t *testing.T) { - t.Run("tool_content_content_text", func(t *testing.T) { - runGolden(t, func() ToolCallContent { return ToolContent(TextBlock("Analysis complete. Found 3 issues.")) }) - }) - t.Run("tool_content_diff", func(t *testing.T) { - runGolden(t, func() ToolCallContent { - old := "{\n \"debug\": false\n}" - return ToolDiffContent("/home/user/project/src/config.json", "{\n \"debug\": true\n}", old) - }) - }) + t.Run("tool_content_content_text", runGolden( + func() ToolCallContent { return ToolContent(TextBlock("Analysis complete. Found 3 issues.")) }, + func() ToolCallContent { + return NewToolCallContentContent(TextBlock("Analysis complete. Found 3 issues.")) + }, + )) + t.Run("tool_content_diff", runGolden(func() ToolCallContent { + old := "{\n \"debug\": false\n}" + return ToolDiffContent("/home/user/project/src/config.json", "{\n \"debug\": true\n}", old) + })) + t.Run("tool_content_diff_no_old", runGolden( + func() ToolCallContent { + return ToolDiffContent("/home/user/project/src/config.json", "{\n \"debug\": true\n}") + }, + func() ToolCallContent { + return NewToolCallContentDiff("/home/user/project/src/config.json", "{\n \"debug\": true\n}") + }, + )) + t.Run("tool_content_terminal", runGolden( + func() ToolCallContent { return ToolTerminalRef("term_001") }, + func() ToolCallContent { return NewToolCallContentTerminal("term_001") }, + )) } func TestJSONGolden_RequestPermissionOutcome(t *testing.T) { - t.Run("permission_outcome_selected", func(t *testing.T) { - runGolden(t, - func() RequestPermissionOutcome { - return RequestPermissionOutcome{Selected: &RequestPermissionOutcomeSelected{Outcome: "selected", OptionId: "allow-once"}} - }, - func() RequestPermissionOutcome { - return NewRequestPermissionOutcomeSelected("allow-once") - }, - ) - }) - t.Run("permission_outcome_cancelled", func(t *testing.T) { - runGolden(t, - func() RequestPermissionOutcome { - return RequestPermissionOutcome{Cancelled: &RequestPermissionOutcomeCancelled{Outcome: "cancelled"}} - }, - func() RequestPermissionOutcome { return NewRequestPermissionOutcomeCancelled() }, - ) - }) + t.Run("permission_outcome_selected", runGolden( + func() RequestPermissionOutcome { + return RequestPermissionOutcome{Selected: &RequestPermissionOutcomeSelected{Outcome: "selected", OptionId: "allow-once"}} + }, + func() RequestPermissionOutcome { + return NewRequestPermissionOutcomeSelected("allow-once") + }, + )) + t.Run("permission_outcome_cancelled", runGolden( + func() RequestPermissionOutcome { + return RequestPermissionOutcome{Cancelled: &RequestPermissionOutcomeCancelled{Outcome: "cancelled"}} + }, + func() RequestPermissionOutcome { return NewRequestPermissionOutcomeCancelled() }, + )) } func TestJSONGolden_SessionUpdates(t *testing.T) { - t.Run("session_update_user_message_chunk", func(t *testing.T) { - runGolden(t, - func() SessionUpdate { - return SessionUpdate{UserMessageChunk: &SessionUpdateUserMessageChunk{Content: TextBlock("What's the capital of France?")}} - }, - func() SessionUpdate { return UpdateUserMessageText("What's the capital of France?") }, - ) - }) - t.Run("session_update_agent_message_chunk", func(t *testing.T) { - runGolden(t, - func() SessionUpdate { - return SessionUpdate{AgentMessageChunk: &SessionUpdateAgentMessageChunk{Content: TextBlock("The capital of France is Paris.")}} - }, - func() SessionUpdate { return UpdateAgentMessageText("The capital of France is Paris.") }, - ) - }) - t.Run("session_update_agent_thought_chunk", func(t *testing.T) { - runGolden(t, - func() SessionUpdate { - return SessionUpdate{AgentThoughtChunk: &SessionUpdateAgentThoughtChunk{Content: TextBlock("Thinking about best approach...")}} - }, - func() SessionUpdate { return UpdateAgentThoughtText("Thinking about best approach...") }, - ) - }) - t.Run("session_update_plan", func(t *testing.T) { - runGolden(t, - func() SessionUpdate { - return SessionUpdate{Plan: &SessionUpdatePlan{Entries: []PlanEntry{{Content: "Check for syntax errors", Priority: PlanEntryPriorityHigh, Status: PlanEntryStatusPending}, {Content: "Identify potential type issues", Priority: PlanEntryPriorityMedium, Status: PlanEntryStatusPending}}}} - }, - func() SessionUpdate { - return UpdatePlan( - PlanEntry{Content: "Check for syntax errors", Priority: PlanEntryPriorityHigh, Status: PlanEntryStatusPending}, - PlanEntry{Content: "Identify potential type issues", Priority: PlanEntryPriorityMedium, Status: PlanEntryStatusPending}, - ) - }, - ) - }) - t.Run("session_update_tool_call", func(t *testing.T) { - runGolden(t, - func() SessionUpdate { - return SessionUpdate{ToolCall: &SessionUpdateToolCall{ToolCallId: "call_001", Title: "Reading configuration file", Kind: ToolKindRead, Status: ToolCallStatusPending}} - }, - func() SessionUpdate { - return StartToolCall("call_001", "Reading configuration file", WithStartKind(ToolKindRead), WithStartStatus(ToolCallStatusPending)) - }, - ) - }) - t.Run("session_update_tool_call_update_content", func(t *testing.T) { - runGolden(t, - func() SessionUpdate { - return SessionUpdate{ToolCallUpdate: &SessionUpdateToolCallUpdate{ToolCallId: "call_001", Status: Ptr(ToolCallStatusInProgress), Content: []ToolCallContent{ToolContent(TextBlock("Found 3 configuration files..."))}}} - }, - func() SessionUpdate { - return UpdateToolCall("call_001", WithUpdateStatus(ToolCallStatusInProgress), WithUpdateContent([]ToolCallContent{ToolContent(TextBlock("Found 3 configuration files..."))})) - }, - ) - }) + t.Run("session_update_user_message_chunk", runGolden( + func() SessionUpdate { + return SessionUpdate{UserMessageChunk: &SessionUpdateUserMessageChunk{Content: TextBlock("What's the capital of France?")}} + }, + func() SessionUpdate { return UpdateUserMessageText("What's the capital of France?") }, + func() SessionUpdate { + return NewSessionUpdateUserMessageChunk(TextBlock("What's the capital of France?")) + }, + )) + t.Run("session_update_agent_message_chunk", runGolden( + func() SessionUpdate { + return SessionUpdate{AgentMessageChunk: &SessionUpdateAgentMessageChunk{Content: TextBlock("The capital of France is Paris.")}} + }, + func() SessionUpdate { return UpdateAgentMessageText("The capital of France is Paris.") }, + func() SessionUpdate { + return NewSessionUpdateAgentMessageChunk(TextBlock("The capital of France is Paris.")) + }, + )) + t.Run("session_update_agent_thought_chunk", runGolden( + func() SessionUpdate { + return SessionUpdate{AgentThoughtChunk: &SessionUpdateAgentThoughtChunk{Content: TextBlock("Thinking about best approach...")}} + }, + func() SessionUpdate { return UpdateAgentThoughtText("Thinking about best approach...") }, + func() SessionUpdate { + return NewSessionUpdateAgentThoughtChunk(TextBlock("Thinking about best approach...")) + }, + )) + t.Run("session_update_plan", runGolden( + func() SessionUpdate { + return SessionUpdate{Plan: &SessionUpdatePlan{Entries: []PlanEntry{{Content: "Check for syntax errors", Priority: PlanEntryPriorityHigh, Status: PlanEntryStatusPending}, {Content: "Identify potential type issues", Priority: PlanEntryPriorityMedium, Status: PlanEntryStatusPending}}}} + }, + func() SessionUpdate { + return UpdatePlan( + PlanEntry{Content: "Check for syntax errors", Priority: PlanEntryPriorityHigh, Status: PlanEntryStatusPending}, + PlanEntry{Content: "Identify potential type issues", Priority: PlanEntryPriorityMedium, Status: PlanEntryStatusPending}, + ) + }, + func() SessionUpdate { + return NewSessionUpdatePlan([]PlanEntry{{Content: "Check for syntax errors", Priority: PlanEntryPriorityHigh, Status: PlanEntryStatusPending}, {Content: "Identify potential type issues", Priority: PlanEntryPriorityMedium, Status: PlanEntryStatusPending}}) + }, + )) + t.Run("session_update_tool_call", runGolden( + func() SessionUpdate { + return SessionUpdate{ToolCall: &SessionUpdateToolCall{ToolCallId: "call_001", Title: "Reading configuration file", Kind: ToolKindRead, Status: ToolCallStatusPending}} + }, + func() SessionUpdate { + return StartToolCall("call_001", "Reading configuration file", WithStartKind(ToolKindRead), WithStartStatus(ToolCallStatusPending)) + }, + )) + t.Run("session_update_tool_call_minimal", runGolden( + func() SessionUpdate { return NewSessionUpdateToolCall("call_001", "Reading configuration file") }, + )) + t.Run("session_update_tool_call_read", runGolden( + func() SessionUpdate { + return StartReadToolCall("call_001", "Reading configuration file", "/home/user/project/src/config.json") + }, + )) + t.Run("session_update_tool_call_edit", runGolden( + func() SessionUpdate { + return StartEditToolCall("call_003", "Apply edit", "/home/user/project/src/config.json", "print('hello')") + }, + )) + t.Run("session_update_tool_call_locations_rawinput", runGolden( + func() SessionUpdate { + return StartToolCall("call_lr", "Tracking file", WithStartLocations([]ToolCallLocation{{Path: "/home/user/project/src/config.json"}})) + }, + )) + t.Run("session_update_tool_call_update_content", runGolden( + func() SessionUpdate { + return SessionUpdate{ToolCallUpdate: &SessionUpdateToolCallUpdate{ToolCallId: "call_001", Status: Ptr(ToolCallStatusInProgress), Content: []ToolCallContent{ToolContent(TextBlock("Found 3 configuration files..."))}}} + }, + func() SessionUpdate { + return UpdateToolCall("call_001", WithUpdateStatus(ToolCallStatusInProgress), WithUpdateContent([]ToolCallContent{ToolContent(TextBlock("Found 3 configuration files..."))})) + }, + )) + t.Run("session_update_tool_call_update_minimal", runGolden( + func() SessionUpdate { return NewSessionUpdateToolCallUpdate("call_001") }, + )) + t.Run("session_update_tool_call_update_more_fields", runGolden( + func() SessionUpdate { + return UpdateToolCall( + "call_010", + WithUpdateTitle("Processing changes"), + WithUpdateKind(ToolKindEdit), + WithUpdateStatus(ToolCallStatusCompleted), + WithUpdateLocations([]ToolCallLocation{{Path: "/home/user/project/src/config.json"}}), + WithUpdateRawInput(map[string]any{"path": "/home/user/project/src/config.json"}), + WithUpdateRawOutput(map[string]any{"result": "ok"}), + WithUpdateContent([]ToolCallContent{ToolContent(TextBlock("Edit completed."))}), + ) + }, + )) } func TestJSONGolden_MethodPayloads(t *testing.T) { - t.Run("initialize_request", func(t *testing.T) { - runGolden(t, func() InitializeRequest { - return InitializeRequest{ProtocolVersion: 1, ClientCapabilities: ClientCapabilities{Fs: FileSystemCapability{ReadTextFile: true, WriteTextFile: true}}} - }) - }) - t.Run("initialize_response", func(t *testing.T) { - runGolden(t, func() InitializeResponse { - return InitializeResponse{ProtocolVersion: 1, AgentCapabilities: AgentCapabilities{LoadSession: true, PromptCapabilities: PromptCapabilities{Image: true, Audio: true, EmbeddedContext: true}}, AuthMethods: []AuthMethod{}} - }) - }) - t.Run("new_session_request", func(t *testing.T) { - runGolden(t, func() NewSessionRequest { - return NewSessionRequest{Cwd: "/home/user/project", McpServers: []McpServer{{Name: "filesystem", Command: "/path/to/mcp-server", Args: []string{"--stdio"}, Env: []EnvVariable{}}}} - }) - }) - t.Run("new_session_response", func(t *testing.T) { - runGolden(t, func() NewSessionResponse { return NewSessionResponse{SessionId: "sess_abc123def456"} }) - }) - t.Run("prompt_request", func(t *testing.T) { - runGolden(t, func() PromptRequest { - return PromptRequest{SessionId: "sess_abc123def456", Prompt: []ContentBlock{TextBlock("Can you analyze this code for potential issues?"), ResourceBlock(EmbeddedResource{Resource: EmbeddedResourceResource{TextResourceContents: &TextResourceContents{Uri: "file:///home/user/project/main.py", MimeType: Ptr("text/x-python"), Text: "def process_data(items):\n for item in items:\n print(item)"}}})}} - }) - }) - t.Run("fs_read_text_file_request", func(t *testing.T) { - runGolden(t, func() ReadTextFileRequest { - line, limit := 10, 50 - return ReadTextFileRequest{SessionId: "sess_abc123def456", Path: "/home/user/project/src/main.py", Line: &line, Limit: &limit} - }) - }) - t.Run("fs_read_text_file_response", func(t *testing.T) { - runGolden(t, func() ReadTextFileResponse { - return ReadTextFileResponse{Content: "def hello_world():\n print('Hello, world!')\n"} - }) - }) - t.Run("fs_write_text_file_request", func(t *testing.T) { - runGolden(t, func() WriteTextFileRequest { - return WriteTextFileRequest{SessionId: "sess_abc123def456", Path: "/home/user/project/config.json", Content: "{\n \"debug\": true,\n \"version\": \"1.0.0\"\n}"} - }) - }) - t.Run("request_permission_request", func(t *testing.T) { - runGolden(t, func() RequestPermissionRequest { - return RequestPermissionRequest{SessionId: "sess_abc123def456", ToolCall: ToolCallUpdate{ToolCallId: "call_001"}, Options: []PermissionOption{{OptionId: "allow-once", Name: "Allow once", Kind: PermissionOptionKindAllowOnce}, {OptionId: "reject-once", Name: "Reject", Kind: PermissionOptionKindRejectOnce}}} - }) - }) - t.Run("request_permission_response_selected", func(t *testing.T) { - runGolden(t, func() RequestPermissionResponse { - return RequestPermissionResponse{Outcome: RequestPermissionOutcome{Selected: &RequestPermissionOutcomeSelected{Outcome: "selected", OptionId: "allow-once"}}} - }) - }) - t.Run("cancel_notification", func(t *testing.T) { - runGolden(t, func() CancelNotification { return CancelNotification{SessionId: "sess_abc123def456"} }) - }) + t.Run("initialize_request", runGolden(func() InitializeRequest { + return InitializeRequest{ProtocolVersion: 1, ClientCapabilities: ClientCapabilities{Fs: FileSystemCapability{ReadTextFile: true, WriteTextFile: true}}} + })) + t.Run("initialize_response", runGolden(func() InitializeResponse { + return InitializeResponse{ProtocolVersion: 1, AgentCapabilities: AgentCapabilities{LoadSession: true, PromptCapabilities: PromptCapabilities{Image: true, Audio: true, EmbeddedContext: true}}, AuthMethods: []AuthMethod{}} + })) + t.Run("new_session_request", runGolden(func() NewSessionRequest { + return NewSessionRequest{Cwd: "/home/user/project", McpServers: []McpServer{{Name: "filesystem", Command: "/path/to/mcp-server", Args: []string{"--stdio"}, Env: []EnvVariable{}}}} + })) + t.Run("new_session_response", runGolden(func() NewSessionResponse { return NewSessionResponse{SessionId: "sess_abc123def456"} })) + t.Run("prompt_request", runGolden(func() PromptRequest { + return PromptRequest{SessionId: "sess_abc123def456", Prompt: []ContentBlock{TextBlock("Can you analyze this code for potential issues?"), ResourceBlock(EmbeddedResource{Resource: EmbeddedResourceResource{TextResourceContents: &TextResourceContents{Uri: "file:///home/user/project/main.py", MimeType: Ptr("text/x-python"), Text: "def process_data(items):\n for item in items:\n print(item)"}}})}} + })) + t.Run("fs_read_text_file_request", runGolden(func() ReadTextFileRequest { + line, limit := 10, 50 + return ReadTextFileRequest{SessionId: "sess_abc123def456", Path: "/home/user/project/src/main.py", Line: &line, Limit: &limit} + })) + t.Run("fs_read_text_file_response", runGolden(func() ReadTextFileResponse { + return ReadTextFileResponse{Content: "def hello_world():\n print('Hello, world!')\n"} + })) + t.Run("fs_write_text_file_request", runGolden(func() WriteTextFileRequest { + return WriteTextFileRequest{SessionId: "sess_abc123def456", Path: "/home/user/project/config.json", Content: "{\n \"debug\": true,\n \"version\": \"1.0.0\"\n}"} + })) + t.Run("request_permission_request", runGolden(func() RequestPermissionRequest { + return RequestPermissionRequest{SessionId: "sess_abc123def456", ToolCall: ToolCallUpdate{ToolCallId: "call_001"}, Options: []PermissionOption{{OptionId: "allow-once", Name: "Allow once", Kind: PermissionOptionKindAllowOnce}, {OptionId: "reject-once", Name: "Reject", Kind: PermissionOptionKindRejectOnce}}} + })) + t.Run("request_permission_response_selected", runGolden(func() RequestPermissionResponse { + return RequestPermissionResponse{Outcome: RequestPermissionOutcome{Selected: &RequestPermissionOutcomeSelected{Outcome: "selected", OptionId: "allow-once"}}} + })) + t.Run("cancel_notification", runGolden(func() CancelNotification { return CancelNotification{SessionId: "sess_abc123def456"} })) } diff --git a/go/testdata/json_golden/session_update_tool_call_edit.json b/go/testdata/json_golden/session_update_tool_call_edit.json new file mode 100644 index 0000000..6ddd93d --- /dev/null +++ b/go/testdata/json_golden/session_update_tool_call_edit.json @@ -0,0 +1,17 @@ +{ + "sessionUpdate": "tool_call", + "toolCallId": "call_003", + "title": "Apply edit", + "kind": "edit", + "status": "pending", + "locations": [ + { + "path": "/home/user/project/src/config.json" + } + ], + "rawInput": { + "path": "/home/user/project/src/config.json", + "content": "print('hello')" + } +} + diff --git a/go/testdata/json_golden/session_update_tool_call_locations_rawinput.json b/go/testdata/json_golden/session_update_tool_call_locations_rawinput.json new file mode 100644 index 0000000..d76507a --- /dev/null +++ b/go/testdata/json_golden/session_update_tool_call_locations_rawinput.json @@ -0,0 +1,14 @@ +{ + "sessionUpdate": "tool_call", + "toolCallId": "call_lr", + "title": "Tracking file", + "locations": [ + { + "path": "/home/user/project/src/config.json" + } + ], + "rawInput": { + "path": "/home/user/project/src/config.json" + } +} + diff --git a/go/testdata/json_golden/session_update_tool_call_minimal.json b/go/testdata/json_golden/session_update_tool_call_minimal.json new file mode 100644 index 0000000..af5edd7 --- /dev/null +++ b/go/testdata/json_golden/session_update_tool_call_minimal.json @@ -0,0 +1,6 @@ +{ + "sessionUpdate": "tool_call", + "toolCallId": "call_001", + "title": "Reading configuration file" +} + diff --git a/go/testdata/json_golden/session_update_tool_call_read.json b/go/testdata/json_golden/session_update_tool_call_read.json new file mode 100644 index 0000000..bfc2008 --- /dev/null +++ b/go/testdata/json_golden/session_update_tool_call_read.json @@ -0,0 +1,16 @@ +{ + "sessionUpdate": "tool_call", + "toolCallId": "call_001", + "title": "Reading configuration file", + "kind": "read", + "status": "pending", + "locations": [ + { + "path": "/home/user/project/src/config.json" + } + ], + "rawInput": { + "path": "/home/user/project/src/config.json" + } +} + diff --git a/go/testdata/json_golden/session_update_tool_call_update_minimal.json b/go/testdata/json_golden/session_update_tool_call_update_minimal.json new file mode 100644 index 0000000..4493e55 --- /dev/null +++ b/go/testdata/json_golden/session_update_tool_call_update_minimal.json @@ -0,0 +1,5 @@ +{ + "sessionUpdate": "tool_call_update", + "toolCallId": "call_001" +} + diff --git a/go/testdata/json_golden/session_update_tool_call_update_more_fields.json b/go/testdata/json_golden/session_update_tool_call_update_more_fields.json new file mode 100644 index 0000000..1469cae --- /dev/null +++ b/go/testdata/json_golden/session_update_tool_call_update_more_fields.json @@ -0,0 +1,28 @@ +{ + "sessionUpdate": "tool_call_update", + "toolCallId": "call_010", + "title": "Processing changes", + "kind": "edit", + "status": "completed", + "locations": [ + { + "path": "/home/user/project/src/config.json" + } + ], + "rawInput": { + "path": "/home/user/project/src/config.json" + }, + "rawOutput": { + "result": "ok" + }, + "content": [ + { + "type": "content", + "content": { + "type": "text", + "text": "Edit completed." + } + } + ] +} + diff --git a/go/testdata/json_golden/tool_content_diff_no_old.json b/go/testdata/json_golden/tool_content_diff_no_old.json new file mode 100644 index 0000000..e14cbe9 --- /dev/null +++ b/go/testdata/json_golden/tool_content_diff_no_old.json @@ -0,0 +1,6 @@ +{ + "type": "diff", + "path": "/home/user/project/src/config.json", + "newText": "{\n \"debug\": true\n}" +} + diff --git a/go/testdata/json_golden/tool_content_terminal.json b/go/testdata/json_golden/tool_content_terminal.json new file mode 100644 index 0000000..387b7d8 --- /dev/null +++ b/go/testdata/json_golden/tool_content_terminal.json @@ -0,0 +1,5 @@ +{ + "type": "terminal", + "terminalId": "term_001" +} + From bdd623def0ca24eb3c3e4154a43d73583ce6b827 Mon Sep 17 00:00:00 2001 From: Thomas Kosiewski Date: Tue, 2 Sep 2025 13:22:11 +0200 Subject: [PATCH 15/22] refactor: improve code generation with cleaner function signatures and JSON marshaling Change-Id: Ia6003b4adb006cce9db9e6ab9b0a17294bfcfa44 Signed-off-by: Thomas Kosiewski --- go/cmd/generate/internal/emit/dispatch.go | 4 +- .../internal/emit/dispatch_helpers.go | 2 +- go/cmd/generate/internal/emit/doc.go | 4 + go/cmd/generate/internal/emit/types.go | 151 +++++++++--------- go/cmd/generate/internal/ir/ir.go | 7 +- go/cmd/generate/internal/load/load.go | 2 + go/cmd/generate/internal/util/util.go | 2 + go/example/agent/main.go | 2 +- go/example/claude-code/main.go | 2 +- go/example/client/main.go | 2 +- go/example/gemini/main.go | 2 +- go/types_gen.go | 45 ------ 12 files changed, 96 insertions(+), 129 deletions(-) create mode 100644 go/cmd/generate/internal/emit/doc.go diff --git a/go/cmd/generate/internal/emit/dispatch.go b/go/cmd/generate/internal/emit/dispatch.go index 7b331fe..9c06ca9 100644 --- a/go/cmd/generate/internal/emit/dispatch.go +++ b/go/cmd/generate/internal/emit/dispatch.go @@ -79,7 +79,7 @@ func WriteDispatchJen(outDir string, schema *load.Schema, meta *load.Meta) error } else if ir.IsNullResponse(schema.Defs[respName]) { caseBody = append(caseBody, jCallRequestNoResp(recv, methodName)...) } else { - caseBody = append(caseBody, jCallRequestWithResp(recv, methodName, respName)...) + caseBody = append(caseBody, jCallRequestWithResp(recv, methodName)...) } } if len(caseBody) > 0 { @@ -186,7 +186,7 @@ func WriteDispatchJen(outDir string, schema *load.Schema, meta *load.Meta) error if ir.IsNullResponse(schema.Defs[respName]) { body = append(body, jCallRequestNoResp(recv, methodName)...) } else { - body = append(body, jCallRequestWithResp(recv, methodName, respName)...) + body = append(body, jCallRequestWithResp(recv, methodName)...) } } if len(body) > 0 { diff --git a/go/cmd/generate/internal/emit/dispatch_helpers.go b/go/cmd/generate/internal/emit/dispatch_helpers.go index 3167163..2a58854 100644 --- a/go/cmd/generate/internal/emit/dispatch_helpers.go +++ b/go/cmd/generate/internal/emit/dispatch_helpers.go @@ -62,7 +62,7 @@ func jCallRequestNoResp(recv, methodName string) []Code { } } -func jCallRequestWithResp(recv, methodName, respType string) []Code { +func jCallRequestWithResp(recv, methodName string) []Code { return []Code{ List(Id("resp"), Id("err")).Op(":=").Id(recv).Dot(methodName).Call(Id("ctx"), Id("p")), If(Id("err").Op("!=").Nil()).Block(jRetToReqErr()), diff --git a/go/cmd/generate/internal/emit/doc.go b/go/cmd/generate/internal/emit/doc.go new file mode 100644 index 0000000..b68cdf3 --- /dev/null +++ b/go/cmd/generate/internal/emit/doc.go @@ -0,0 +1,4 @@ +// Package emit contains helpers to generate the Go SDK from the +// intermediate representation produced by the ACP generator. It +// encapsulates code emission for types, helpers, and dispatch logic. +package emit diff --git a/go/cmd/generate/internal/emit/types.go b/go/cmd/generate/internal/emit/types.go index 4591969..39855c8 100644 --- a/go/cmd/generate/internal/emit/types.go +++ b/go/cmd/generate/internal/emit/types.go @@ -88,7 +88,7 @@ func WriteTypesJen(outDir string, schema *load.Schema, meta *load.Meta) error { if _, ok := req[pk]; !ok { // Default: omit if empty, except for specific always-present fields // Ensure InitializeResponse.authMethods is always encoded (even when empty) - if !(name == "InitializeResponse" && pk == "authMethods") { + if name != "InitializeResponse" || pk != "authMethods" { tag = pk + ",omitempty" } } @@ -583,86 +583,87 @@ func emitUnion(f *File, name string, defs []*load.Definition, exactlyOne bool) { // Null-only variant encodes to JSON null if vi.isNull { gg.Return(Qual("encoding/json", "Marshal").Call(Nil())) - } - // Marshal variant to map for discriminant injection and shaping - gg.Var().Id("m").Map(String()).Any() - gg.List(Id("_b"), Id("_e")).Op(":=").Qual("encoding/json", "Marshal").Call(Op("*").Id("u").Dot(vi.fieldName)) - gg.If(Id("_e").Op("!=").Nil()).Block(Return(Index().Byte().Values(), Id("_e"))) - gg.If(Qual("encoding/json", "Unmarshal").Call(Id("_b"), Op("&").Id("m")).Op("!=").Nil()).Block(Return(Index().Byte().Values(), Qual("errors", "New").Call(Lit("invalid variant payload")))) - // Inject const discriminants - if len(vi.constPairs) > 0 { - for _, kv := range vi.constPairs { - gg.Id("m").Index(Lit(kv[0])).Op("=").Lit(kv[1]) + } else { + // Marshal variant to map for discriminant injection and shaping + gg.Var().Id("m").Map(String()).Any() + gg.List(Id("_b"), Id("_e")).Op(":=").Qual("encoding/json", "Marshal").Call(Op("*").Id("u").Dot(vi.fieldName)) + gg.If(Id("_e").Op("!=").Nil()).Block(Return(Index().Byte().Values(), Id("_e"))) + gg.If(Qual("encoding/json", "Unmarshal").Call(Id("_b"), Op("&").Id("m")).Op("!=").Nil()).Block(Return(Index().Byte().Values(), Qual("errors", "New").Call(Lit("invalid variant payload")))) + // Inject const discriminants + if len(vi.constPairs) > 0 { + for _, kv := range vi.constPairs { + gg.Id("m").Index(Lit(kv[0])).Op("=").Lit(kv[1]) + } } - } - // Special shaping for ContentBlock variants to preserve exact wire JSON - if name == "ContentBlock" { - switch vi.discValue { - case "text": - gg.Block( - Var().Id("nm").Map(String()).Any(), - Id("nm").Op("=").Make(Map(String()).Any()), - Id("nm").Index(Lit("type")).Op("=").Lit("text"), - Id("nm").Index(Lit("text")).Op("=").Id("m").Index(Lit("text")), - Return(Qual("encoding/json", "Marshal").Call(Id("nm"))), - ) - case "image": - gg.Block( - Var().Id("nm").Map(String()).Any(), - Id("nm").Op("=").Make(Map(String()).Any()), - Id("nm").Index(Lit("type")).Op("=").Lit("image"), - Id("nm").Index(Lit("data")).Op("=").Id("m").Index(Lit("data")), - Id("nm").Index(Lit("mimeType")).Op("=").Id("m").Index(Lit("mimeType")), - // Only include uri if present; do not emit null - If(List(Id("_v"), Id("_ok")).Op(":=").Id("m").Index(Lit("uri")), Id("_ok")).Block( - Id("nm").Index(Lit("uri")).Op("=").Id("_v"), - ), - Return(Qual("encoding/json", "Marshal").Call(Id("nm"))), - ) - case "audio": - gg.Block( - Var().Id("nm").Map(String()).Any(), - Id("nm").Op("=").Make(Map(String()).Any()), - Id("nm").Index(Lit("type")).Op("=").Lit("audio"), - Id("nm").Index(Lit("data")).Op("=").Id("m").Index(Lit("data")), - Id("nm").Index(Lit("mimeType")).Op("=").Id("m").Index(Lit("mimeType")), - Return(Qual("encoding/json", "Marshal").Call(Id("nm"))), - ) - case "resource_link": - gg.BlockFunc(func(b *Group) { - b.Var().Id("nm").Map(String()).Any() - b.Id("nm").Op("=").Make(Map(String()).Any()) - b.Id("nm").Index(Lit("type")).Op("=").Lit("resource_link") - b.Id("nm").Index(Lit("name")).Op("=").Id("m").Index(Lit("name")) - b.Id("nm").Index(Lit("uri")).Op("=").Id("m").Index(Lit("uri")) - // Only include optional keys if present - b.If(List(Id("v1"), Id("ok1")).Op(":=").Id("m").Index(Lit("description")), Id("ok1")).Block( - Id("nm").Index(Lit("description")).Op("=").Id("v1"), + // Special shaping for ContentBlock variants to preserve exact wire JSON + if name == "ContentBlock" { + switch vi.discValue { + case "text": + gg.Block( + Var().Id("nm").Map(String()).Any(), + Id("nm").Op("=").Make(Map(String()).Any()), + Id("nm").Index(Lit("type")).Op("=").Lit("text"), + Id("nm").Index(Lit("text")).Op("=").Id("m").Index(Lit("text")), + Return(Qual("encoding/json", "Marshal").Call(Id("nm"))), ) - b.If(List(Id("v2"), Id("ok2")).Op(":=").Id("m").Index(Lit("mimeType")), Id("ok2")).Block( - Id("nm").Index(Lit("mimeType")).Op("=").Id("v2"), + case "image": + gg.Block( + Var().Id("nm").Map(String()).Any(), + Id("nm").Op("=").Make(Map(String()).Any()), + Id("nm").Index(Lit("type")).Op("=").Lit("image"), + Id("nm").Index(Lit("data")).Op("=").Id("m").Index(Lit("data")), + Id("nm").Index(Lit("mimeType")).Op("=").Id("m").Index(Lit("mimeType")), + // Only include uri if present; do not emit null + If(List(Id("_v"), Id("_ok")).Op(":=").Id("m").Index(Lit("uri")), Id("_ok")).Block( + Id("nm").Index(Lit("uri")).Op("=").Id("_v"), + ), + Return(Qual("encoding/json", "Marshal").Call(Id("nm"))), ) - b.If(List(Id("v3"), Id("ok3")).Op(":=").Id("m").Index(Lit("size")), Id("ok3")).Block( - Id("nm").Index(Lit("size")).Op("=").Id("v3"), + case "audio": + gg.Block( + Var().Id("nm").Map(String()).Any(), + Id("nm").Op("=").Make(Map(String()).Any()), + Id("nm").Index(Lit("type")).Op("=").Lit("audio"), + Id("nm").Index(Lit("data")).Op("=").Id("m").Index(Lit("data")), + Id("nm").Index(Lit("mimeType")).Op("=").Id("m").Index(Lit("mimeType")), + Return(Qual("encoding/json", "Marshal").Call(Id("nm"))), ) - b.If(List(Id("v4"), Id("ok4")).Op(":=").Id("m").Index(Lit("title")), Id("ok4")).Block( - Id("nm").Index(Lit("title")).Op("=").Id("v4"), + case "resource_link": + gg.BlockFunc(func(b *Group) { + b.Var().Id("nm").Map(String()).Any() + b.Id("nm").Op("=").Make(Map(String()).Any()) + b.Id("nm").Index(Lit("type")).Op("=").Lit("resource_link") + b.Id("nm").Index(Lit("name")).Op("=").Id("m").Index(Lit("name")) + b.Id("nm").Index(Lit("uri")).Op("=").Id("m").Index(Lit("uri")) + // Only include optional keys if present + b.If(List(Id("v1"), Id("ok1")).Op(":=").Id("m").Index(Lit("description")), Id("ok1")).Block( + Id("nm").Index(Lit("description")).Op("=").Id("v1"), + ) + b.If(List(Id("v2"), Id("ok2")).Op(":=").Id("m").Index(Lit("mimeType")), Id("ok2")).Block( + Id("nm").Index(Lit("mimeType")).Op("=").Id("v2"), + ) + b.If(List(Id("v3"), Id("ok3")).Op(":=").Id("m").Index(Lit("size")), Id("ok3")).Block( + Id("nm").Index(Lit("size")).Op("=").Id("v3"), + ) + b.If(List(Id("v4"), Id("ok4")).Op(":=").Id("m").Index(Lit("title")), Id("ok4")).Block( + Id("nm").Index(Lit("title")).Op("=").Id("v4"), + ) + b.Return(Qual("encoding/json", "Marshal").Call(Id("nm"))) + }) + case "resource": + gg.Block( + Var().Id("nm").Map(String()).Any(), + Id("nm").Op("=").Make(Map(String()).Any()), + Id("nm").Index(Lit("type")).Op("=").Lit("resource"), + Id("nm").Index(Lit("resource")).Op("=").Id("m").Index(Lit("resource")), + Return(Qual("encoding/json", "Marshal").Call(Id("nm"))), ) - b.Return(Qual("encoding/json", "Marshal").Call(Id("nm"))) - }) - case "resource": - gg.Block( - Var().Id("nm").Map(String()).Any(), - Id("nm").Op("=").Make(Map(String()).Any()), - Id("nm").Index(Lit("type")).Op("=").Lit("resource"), - Id("nm").Index(Lit("resource")).Op("=").Id("m").Index(Lit("resource")), - Return(Qual("encoding/json", "Marshal").Call(Id("nm"))), - ) + } + } + // default: remarshal possibly with injected discriminant + if name != "ContentBlock" { + gg.Return(Qual("encoding/json", "Marshal").Call(Id("m"))) } - } - // default: remarshal possibly with injected discriminant - if name != "ContentBlock" { - gg.Return(Qual("encoding/json", "Marshal").Call(Id("m"))) } }) } diff --git a/go/cmd/generate/internal/ir/ir.go b/go/cmd/generate/internal/ir/ir.go index bd62e43..2a603ec 100644 --- a/go/cmd/generate/internal/ir/ir.go +++ b/go/cmd/generate/internal/ir/ir.go @@ -1,3 +1,6 @@ +// Package ir defines the intermediate representation used by the Go +// code generator. It organizes methods, bindings, and schema-derived +// types so the emit package can produce helpers, types, and dispatch code. package ir import ( @@ -144,14 +147,14 @@ func BuildMethodGroups(schema *load.Schema, meta *load.Meta) Groups { } // Post-process bindings and docs-ignore for _, mi := range groups { - mi.Binding = classifyBinding(schema, meta, mi) + mi.Binding = classifyBinding(schema, mi) mi.DocsIgnored = isDocsIgnoredMethod(schema, mi) } return groups } // classifyBinding determines interface binding for each method. -func classifyBinding(schema *load.Schema, meta *load.Meta, mi *MethodInfo) MethodBinding { +func classifyBinding(schema *load.Schema, mi *MethodInfo) MethodBinding { if mi == nil { return BindUnknown } diff --git a/go/cmd/generate/internal/load/load.go b/go/cmd/generate/internal/load/load.go index 27cbd16..e9a044f 100644 --- a/go/cmd/generate/internal/load/load.go +++ b/go/cmd/generate/internal/load/load.go @@ -1,3 +1,5 @@ +// Package load provides utilities to read the ACP JSON schema and +// accompanying metadata into minimal structures used by the generator. package load import ( diff --git a/go/cmd/generate/internal/util/util.go b/go/cmd/generate/internal/util/util.go index 8912623..0ac7871 100644 --- a/go/cmd/generate/internal/util/util.go +++ b/go/cmd/generate/internal/util/util.go @@ -1,3 +1,5 @@ +// Package util contains small string and identifier helpers used by the +// code generator for formatting names and comments. package util import ( diff --git a/go/example/agent/main.go b/go/example/agent/main.go index aa7cb51..44a5fcb 100644 --- a/go/example/agent/main.go +++ b/go/example/agent/main.go @@ -61,7 +61,7 @@ func (a *exampleAgent) Cancel(ctx context.Context, params acp.CancelNotification return nil } -func (a *exampleAgent) Prompt(ctx context.Context, params acp.PromptRequest) (acp.PromptResponse, error) { +func (a *exampleAgent) Prompt(_ context.Context, params acp.PromptRequest) (acp.PromptResponse, error) { sid := string(params.SessionId) s, ok := a.sessions[sid] if !ok { diff --git a/go/example/claude-code/main.go b/go/example/claude-code/main.go index b0bb3e2..fcdbbda 100644 --- a/go/example/claude-code/main.go +++ b/go/example/claude-code/main.go @@ -55,7 +55,7 @@ func (c *replClient) RequestPermission(ctx context.Context, params acp.RequestPe continue } idx := -1 - fmt.Sscanf(line, "%d", &idx) + _, _ = fmt.Sscanf(line, "%d", &idx) idx = idx - 1 if idx >= 0 && idx < len(params.Options) { return acp.RequestPermissionResponse{Outcome: acp.RequestPermissionOutcome{Selected: &acp.RequestPermissionOutcomeSelected{OptionId: params.Options[idx].OptionId}}}, nil diff --git a/go/example/client/main.go b/go/example/client/main.go index eb59777..44fb982 100644 --- a/go/example/client/main.go +++ b/go/example/client/main.go @@ -39,7 +39,7 @@ func (e *exampleClient) RequestPermission(ctx context.Context, params acp.Reques continue } idx := -1 - fmt.Sscanf(line, "%d", &idx) + _, _ = fmt.Sscanf(line, "%d", &idx) idx = idx - 1 if idx >= 0 && idx < len(params.Options) { return acp.RequestPermissionResponse{Outcome: acp.RequestPermissionOutcome{Selected: &acp.RequestPermissionOutcomeSelected{OptionId: params.Options[idx].OptionId}}}, nil diff --git a/go/example/gemini/main.go b/go/example/gemini/main.go index 64154d1..694dea1 100644 --- a/go/example/gemini/main.go +++ b/go/example/gemini/main.go @@ -58,7 +58,7 @@ func (c *replClient) RequestPermission(ctx context.Context, params acp.RequestPe continue } idx := -1 - fmt.Sscanf(line, "%d", &idx) + _, _ = fmt.Sscanf(line, "%d", &idx) idx = idx - 1 if idx >= 0 && idx < len(params.Options) { return acp.RequestPermissionResponse{Outcome: acp.RequestPermissionOutcome{Selected: &acp.RequestPermissionOutcomeSelected{OptionId: params.Options[idx].OptionId}}}, nil diff --git a/go/types_gen.go b/go/types_gen.go index 8ede4b8..d15416f 100644 --- a/go/types_gen.go +++ b/go/types_gen.go @@ -292,15 +292,6 @@ func (u AgentResponse) MarshalJSON() ([]byte, error) { } if u.AuthenticateResponse != nil { return json.Marshal(nil) - var m map[string]any - _b, _e := json.Marshal(*u.AuthenticateResponse) - if _e != nil { - return []byte{}, _e - } - if json.Unmarshal(_b, &m) != nil { - return []byte{}, errors.New("invalid variant payload") - } - return json.Marshal(m) } if u.NewSessionResponse != nil { var m map[string]any @@ -315,15 +306,6 @@ func (u AgentResponse) MarshalJSON() ([]byte, error) { } if u.LoadSessionResponse != nil { return json.Marshal(nil) - var m map[string]any - _b, _e := json.Marshal(*u.LoadSessionResponse) - if _e != nil { - return []byte{}, _e - } - if json.Unmarshal(_b, &m) != nil { - return []byte{}, errors.New("invalid variant payload") - } - return json.Marshal(m) } if u.PromptResponse != nil { var m map[string]any @@ -697,15 +679,6 @@ func (u *ClientResponse) UnmarshalJSON(b []byte) error { func (u ClientResponse) MarshalJSON() ([]byte, error) { if u.WriteTextFileResponse != nil { return json.Marshal(nil) - var m map[string]any - _b, _e := json.Marshal(*u.WriteTextFileResponse) - if _e != nil { - return []byte{}, _e - } - if json.Unmarshal(_b, &m) != nil { - return []byte{}, errors.New("invalid variant payload") - } - return json.Marshal(m) } if u.ReadTextFileResponse != nil { var m map[string]any @@ -753,15 +726,6 @@ func (u ClientResponse) MarshalJSON() ([]byte, error) { } if u.ReleaseTerminalResponse != nil { return json.Marshal(nil) - var m map[string]any - _b, _e := json.Marshal(*u.ReleaseTerminalResponse) - if _e != nil { - return []byte{}, _e - } - if json.Unmarshal(_b, &m) != nil { - return []byte{}, errors.New("invalid variant payload") - } - return json.Marshal(m) } if u.WaitForTerminalExitResponse != nil { var m map[string]any @@ -776,15 +740,6 @@ func (u ClientResponse) MarshalJSON() ([]byte, error) { } if u.KillTerminalResponse != nil { return json.Marshal(nil) - var m map[string]any - _b, _e := json.Marshal(*u.KillTerminalResponse) - if _e != nil { - return []byte{}, _e - } - if json.Unmarshal(_b, &m) != nil { - return []byte{}, errors.New("invalid variant payload") - } - return json.Marshal(m) } return []byte{}, nil } From 3bdf7c13b1cae2a6b3d739561272e8aa22b5f31b Mon Sep 17 00:00:00 2001 From: Thomas Kosiewski Date: Tue, 2 Sep 2025 17:58:23 +0200 Subject: [PATCH 16/22] refactor: extract helpers to static file and optimize code generation Change-Id: I410e8a8452949641864ca4e2e876ff2df6d71eee Signed-off-by: Thomas Kosiewski --- go/cmd/generate/internal/emit/helpers.go | 253 +----------- go/helpers.go | 259 ++++++++++++ go/helpers_gen.go | 376 ------------------ go/json_parity_test.go | 48 +-- .../session_update_tool_call_minimal.json | 6 - ...ssion_update_tool_call_update_minimal.json | 5 - 6 files changed, 267 insertions(+), 680 deletions(-) create mode 100644 go/helpers.go delete mode 100644 go/testdata/json_golden/session_update_tool_call_minimal.json delete mode 100644 go/testdata/json_golden/session_update_tool_call_update_minimal.json diff --git a/go/cmd/generate/internal/emit/helpers.go b/go/cmd/generate/internal/emit/helpers.go index 3b6e5dc..2f5ce5a 100644 --- a/go/cmd/generate/internal/emit/helpers.go +++ b/go/cmd/generate/internal/emit/helpers.go @@ -18,204 +18,6 @@ func WriteHelpersJen(outDir string, schema *load.Schema, _ *load.Meta) error { f := NewFile("acp") f.HeaderComment("Code generated by acp-go-generator; DO NOT EDIT.") - // Content helpers - f.Comment("TextBlock constructs a text content block.") - f.Func().Id("TextBlock").Params(Id("text").String()).Id("ContentBlock").Block( - Return(Id("ContentBlock").Values(Dict{ - Id("Text"): Op("&").Id("ContentBlockText").Values(Dict{Id("Type"): Lit("text"), Id("Text"): Id("text")}), - })), - ) - f.Line() - - f.Comment("ImageBlock constructs an inline image content block with base64-encoded data.") - f.Func().Id("ImageBlock").Params(Id("data").String(), Id("mimeType").String()).Id("ContentBlock").Block( - Return(Id("ContentBlock").Values(Dict{ - Id("Image"): Op("&").Id("ContentBlockImage").Values(Dict{Id("Type"): Lit("image"), Id("Data"): Id("data"), Id("MimeType"): Id("mimeType")}), - })), - ) - f.Line() - - f.Comment("AudioBlock constructs an inline audio content block with base64-encoded data.") - f.Func().Id("AudioBlock").Params(Id("data").String(), Id("mimeType").String()).Id("ContentBlock").Block( - Return(Id("ContentBlock").Values(Dict{ - Id("Audio"): Op("&").Id("ContentBlockAudio").Values(Dict{Id("Type"): Lit("audio"), Id("Data"): Id("data"), Id("MimeType"): Id("mimeType")}), - })), - ) - f.Line() - - f.Comment("ResourceLinkBlock constructs a resource_link content block with a name and URI.") - f.Func().Id("ResourceLinkBlock").Params(Id("name").String(), Id("uri").String()).Id("ContentBlock").Block( - Return(Id("ContentBlock").Values(Dict{ - Id("ResourceLink"): Op("&").Id("ContentBlockResourceLink").Values(Dict{Id("Type"): Lit("resource_link"), Id("Name"): Id("name"), Id("Uri"): Id("uri")}), - })), - ) - f.Line() - - f.Comment("ResourceBlock wraps an embedded resource as a content block.") - f.Func().Id("ResourceBlock").Params(Id("res").Id("EmbeddedResource")).Id("ContentBlock").Block( - Var().Id("r").Id("EmbeddedResource").Op("=").Id("res"), - Return(Id("ContentBlock").Values(Dict{ - Id("Resource"): Op("&").Id("ContentBlockResource").Values(Dict{Id("Type"): Lit("resource"), Id("Resource"): Id("r").Dot("Resource")}), - })), - ) - f.Line() - - // ToolCall content helpers - f.Comment("ToolContent wraps a content block as tool-call content.") - f.Func().Id("ToolContent").Params(Id("block").Id("ContentBlock")).Id("ToolCallContent").Block( - Return(Id("ToolCallContent").Values(Dict{ - Id("Content"): Op("&").Id("ToolCallContentContent").Values(Dict{Id("Content"): Id("block"), Id("Type"): Lit("content")}), - })), - ) - f.Line() - - f.Comment("ToolDiffContent constructs a diff tool-call content. If oldText is omitted, the field is left empty.") - f.Func().Id("ToolDiffContent").Params(Id("path").String(), Id("newText").String(), Id("oldText").Op("...").String()).Id("ToolCallContent").Block( - Var().Id("o").Op("*").String(), - If(Id("len").Call(Id("oldText")).Op(">").Lit(0)).Block( - Id("o").Op("=").Op("&").Id("oldText").Index(Lit(0)), - ), - Return(Id("ToolCallContent").Values(Dict{ - Id("Diff"): Op("&").Id("ToolCallContentDiff").Values(Dict{Id("Path"): Id("path"), Id("NewText"): Id("newText"), Id("OldText"): Id("o"), Id("Type"): Lit("diff")}), - })), - ) - f.Line() - - f.Comment("ToolTerminalRef constructs a terminal reference tool-call content.") - f.Func().Id("ToolTerminalRef").Params(Id("terminalId").String()).Id("ToolCallContent").Block( - Return(Id("ToolCallContent").Values(Dict{ - Id("Terminal"): Op("&").Id("ToolCallContentTerminal").Values(Dict{Id("TerminalId"): Id("terminalId"), Id("Type"): Lit("terminal")}), - })), - ) - f.Line() - - // Generic pointer helper - f.Comment("Ptr returns a pointer to v.") - f.Func().Id("Ptr").Types(Id("T").Any()).Params(Id("v").Id("T")).Op("*").Id("T").Block( - Return(Op("&").Id("v")), - ) - - // SessionUpdate helpers (friendly aliases) - f.Line() - f.Comment("UpdateUserMessage constructs a user_message_chunk update with the given content.") - f.Func().Id("UpdateUserMessage").Params(Id("content").Id("ContentBlock")).Id("SessionUpdate").Block( - Return(Id("SessionUpdate").Values(Dict{Id("UserMessageChunk"): Op("&").Id("SessionUpdateUserMessageChunk").Values(Dict{Id("Content"): Id("content")})})), - ) - f.Comment("UpdateUserMessageText constructs a user_message_chunk update from text.") - f.Func().Id("UpdateUserMessageText").Params(Id("text").String()).Id("SessionUpdate").Block( - Return(Id("UpdateUserMessage").Call(Id("TextBlock").Call(Id("text")))), - ) - - f.Comment("UpdateAgentMessage constructs an agent_message_chunk update with the given content.") - f.Func().Id("UpdateAgentMessage").Params(Id("content").Id("ContentBlock")).Id("SessionUpdate").Block( - Return(Id("SessionUpdate").Values(Dict{Id("AgentMessageChunk"): Op("&").Id("SessionUpdateAgentMessageChunk").Values(Dict{Id("Content"): Id("content")})})), - ) - f.Comment("UpdateAgentMessageText constructs an agent_message_chunk update from text.") - f.Func().Id("UpdateAgentMessageText").Params(Id("text").String()).Id("SessionUpdate").Block( - Return(Id("UpdateAgentMessage").Call(Id("TextBlock").Call(Id("text")))), - ) - - f.Comment("UpdateAgentThought constructs an agent_thought_chunk update with the given content.") - f.Func().Id("UpdateAgentThought").Params(Id("content").Id("ContentBlock")).Id("SessionUpdate").Block( - Return(Id("SessionUpdate").Values(Dict{Id("AgentThoughtChunk"): Op("&").Id("SessionUpdateAgentThoughtChunk").Values(Dict{Id("Content"): Id("content")})})), - ) - f.Comment("UpdateAgentThoughtText constructs an agent_thought_chunk update from text.") - f.Func().Id("UpdateAgentThoughtText").Params(Id("text").String()).Id("SessionUpdate").Block( - Return(Id("UpdateAgentThought").Call(Id("TextBlock").Call(Id("text")))), - ) - - f.Comment("UpdatePlan constructs a plan update with the provided entries.") - f.Func().Id("UpdatePlan").Params(Id("entries").Op("...").Id("PlanEntry")).Id("SessionUpdate").Block( - Return(Id("SessionUpdate").Values(Dict{Id("Plan"): Op("&").Id("SessionUpdatePlan").Values(Dict{Id("Entries"): Id("entries")})})), - ) - - // Tool call start helpers with functional options (friendly aliases) - f.Line() - f.Type().Id("ToolCallStartOpt").Func().Params(Id("tc").Op("*").Id("SessionUpdateToolCall")) - f.Comment("StartToolCall constructs a tool_call update with required fields and applies optional modifiers.") - f.Func().Id("StartToolCall").Params(Id("id").Id("ToolCallId"), Id("title").String(), Id("opts").Op("...").Id("ToolCallStartOpt")).Id("SessionUpdate").Block( - Id("tc").Op(":=").Id("SessionUpdateToolCall").Values(Dict{Id("ToolCallId"): Id("id"), Id("Title"): Id("title")}), - For(List(Id("_"), Id("opt")).Op(":=").Range().Id("opts")).Block(Id("opt").Call(Op("&").Id("tc"))), - Return(Id("SessionUpdate").Values(Dict{Id("ToolCall"): Op("&").Id("tc")})), - ) - f.Comment("WithStartKind sets the kind for a tool_call start update.") - f.Func().Id("WithStartKind").Params(Id("k").Id("ToolKind")).Id("ToolCallStartOpt").Block( - Return(Func().Params(Id("tc").Op("*").Id("SessionUpdateToolCall")).Block(Id("tc").Dot("Kind").Op("=").Id("k"))), - ) - f.Comment("WithStartStatus sets the status for a tool_call start update.") - f.Func().Id("WithStartStatus").Params(Id("s").Id("ToolCallStatus")).Id("ToolCallStartOpt").Block( - Return(Func().Params(Id("tc").Op("*").Id("SessionUpdateToolCall")).Block(Id("tc").Dot("Status").Op("=").Id("s"))), - ) - f.Comment("WithStartContent sets the initial content for a tool_call start update.") - f.Func().Id("WithStartContent").Params(Id("c").Index().Id("ToolCallContent")).Id("ToolCallStartOpt").Block( - Return(Func().Params(Id("tc").Op("*").Id("SessionUpdateToolCall")).Block(Id("tc").Dot("Content").Op("=").Id("c"))), - ) - f.Comment("WithStartLocations sets file locations and, if a single path is provided and rawInput is empty, mirrors it as rawInput.path.") - f.Func().Id("WithStartLocations").Params(Id("l").Index().Id("ToolCallLocation")).Id("ToolCallStartOpt").Block( - Return(Func().Params(Id("tc").Op("*").Id("SessionUpdateToolCall")).BlockFunc(func(g *Group) { - g.Id("tc").Dot("Locations").Op("=").Id("l") - g.If(Id("len").Call(Id("l")).Op("==").Lit(1).Op("&&").Id("l").Index(Lit(0)).Dot("Path").Op("!=").Lit("")).BlockFunc(func(h *Group) { - // initialize rawInput if nil - h.If(Id("tc").Dot("RawInput").Op("==").Nil()).Block( - Id("tc").Dot("RawInput").Op("=").Map(String()).Any().Values(Dict{Lit("path"): Id("l").Index(Lit(0)).Dot("Path")}), - ).Else().BlockFunc(func(b *Group) { - b.List(Id("m"), Id("ok")).Op(":=").Id("tc").Dot("RawInput").Assert(Map(String()).Any()) - b.If(Id("ok")).Block( - If(List(Id("_"), Id("exists")).Op(":=").Id("m").Index(Lit("path")), Op("!").Id("exists")).Block( - Id("m").Index(Lit("path")).Op("=").Id("l").Index(Lit(0)).Dot("Path"), - ), - ) - }) - }) - })), - ) - f.Comment("WithStartRawInput sets rawInput for a tool_call start update.") - f.Func().Id("WithStartRawInput").Params(Id("v").Any()).Id("ToolCallStartOpt").Block( - Return(Func().Params(Id("tc").Op("*").Id("SessionUpdateToolCall")).Block(Id("tc").Dot("RawInput").Op("=").Id("v"))), - ) - f.Comment("WithStartRawOutput sets rawOutput for a tool_call start update.") - f.Func().Id("WithStartRawOutput").Params(Id("v").Any()).Id("ToolCallStartOpt").Block( - Return(Func().Params(Id("tc").Op("*").Id("SessionUpdateToolCall")).Block(Id("tc").Dot("RawOutput").Op("=").Id("v"))), - ) - - // Tool call update helpers with functional options (pointer fields; friendly aliases) - f.Line() - f.Type().Id("ToolCallUpdateOpt").Func().Params(Id("tu").Op("*").Id("SessionUpdateToolCallUpdate")) - f.Comment("UpdateToolCall constructs a tool_call_update with the given ID and applies optional modifiers.") - f.Func().Id("UpdateToolCall").Params(Id("id").Id("ToolCallId"), Id("opts").Op("...").Id("ToolCallUpdateOpt")).Id("SessionUpdate").Block( - Id("tu").Op(":=").Id("SessionUpdateToolCallUpdate").Values(Dict{Id("ToolCallId"): Id("id")}), - For(List(Id("_"), Id("opt")).Op(":=").Range().Id("opts")).Block(Id("opt").Call(Op("&").Id("tu"))), - Return(Id("SessionUpdate").Values(Dict{Id("ToolCallUpdate"): Op("&").Id("tu")})), - ) - f.Comment("WithUpdateTitle sets the title for a tool_call_update.") - f.Func().Id("WithUpdateTitle").Params(Id("t").String()).Id("ToolCallUpdateOpt").Block( - Return(Func().Params(Id("tu").Op("*").Id("SessionUpdateToolCallUpdate")).Block(Id("tu").Dot("Title").Op("=").Id("Ptr").Call(Id("t")))), - ) - f.Comment("WithUpdateKind sets the kind for a tool_call_update.") - f.Func().Id("WithUpdateKind").Params(Id("k").Id("ToolKind")).Id("ToolCallUpdateOpt").Block( - Return(Func().Params(Id("tu").Op("*").Id("SessionUpdateToolCallUpdate")).Block(Id("tu").Dot("Kind").Op("=").Id("Ptr").Call(Id("k")))), - ) - f.Comment("WithUpdateStatus sets the status for a tool_call_update.") - f.Func().Id("WithUpdateStatus").Params(Id("s").Id("ToolCallStatus")).Id("ToolCallUpdateOpt").Block( - Return(Func().Params(Id("tu").Op("*").Id("SessionUpdateToolCallUpdate")).Block(Id("tu").Dot("Status").Op("=").Id("Ptr").Call(Id("s")))), - ) - f.Comment("WithUpdateContent replaces the content collection for a tool_call_update.") - f.Func().Id("WithUpdateContent").Params(Id("c").Index().Id("ToolCallContent")).Id("ToolCallUpdateOpt").Block( - Return(Func().Params(Id("tu").Op("*").Id("SessionUpdateToolCallUpdate")).Block(Id("tu").Dot("Content").Op("=").Id("c"))), - ) - f.Comment("WithUpdateLocations replaces the locations collection for a tool_call_update.") - f.Func().Id("WithUpdateLocations").Params(Id("l").Index().Id("ToolCallLocation")).Id("ToolCallUpdateOpt").Block( - Return(Func().Params(Id("tu").Op("*").Id("SessionUpdateToolCallUpdate")).Block(Id("tu").Dot("Locations").Op("=").Id("l"))), - ) - f.Comment("WithUpdateRawInput sets rawInput for a tool_call_update.") - f.Func().Id("WithUpdateRawInput").Params(Id("v").Any()).Id("ToolCallUpdateOpt").Block( - Return(Func().Params(Id("tu").Op("*").Id("SessionUpdateToolCallUpdate")).Block(Id("tu").Dot("RawInput").Op("=").Id("v"))), - ) - f.Comment("WithUpdateRawOutput sets rawOutput for a tool_call_update.") - f.Func().Id("WithUpdateRawOutput").Params(Id("v").Any()).Id("ToolCallUpdateOpt").Block( - Return(Func().Params(Id("tu").Op("*").Id("SessionUpdateToolCallUpdate")).Block(Id("tu").Dot("RawOutput").Op("=").Id("v"))), - ) - // Schema-driven generic helpers: New(required fields only) // Iterate definitions deterministically keys := make([]string, 0, len(schema.Defs)) @@ -232,6 +34,12 @@ func WriteHelpersJen(outDir string, schema *load.Schema, _ *load.Meta) error { if isStringConstUnion(def) { continue } + // Skip generating New... helpers for unions that have stable, static helpers + // implemented in go/helpers.go. + switch name { + case "ContentBlock", "ToolCallContent", "SessionUpdate": + continue + } // Build variant info similarly to types emitter type vinfo struct { fieldName string @@ -329,58 +137,9 @@ func WriteHelpersJen(outDir string, schema *load.Schema, _ *load.Meta) error { } } - // Friendly aliases: opinionated tool call starters for common cases - // StartReadToolCall: sets kind=read, status=pending, locations=[{path}], rawInput={path} - f.Comment("StartReadToolCall constructs a 'tool_call' update for reading a file: kind=read, status=pending, locations=[{path}], rawInput={path}.") - f.Func().Id("StartReadToolCall").Params( - Id("id").Id("ToolCallId"), - Id("title").String(), - Id("path").String(), - Id("opts").Op("...").Id("ToolCallStartOpt"), - ).Id("SessionUpdate").Block( - Id("base").Op(":=").Index().Id("ToolCallStartOpt").Values( - Id("WithStartKind").Call(Id("ToolKindRead")), - Id("WithStartStatus").Call(Id("ToolCallStatusPending")), - Id("WithStartLocations").Call( - Index().Id("ToolCallLocation").Values( - Id("ToolCallLocation").Values(Dict{Id("Path"): Id("path")}), - ), - ), - Id("WithStartRawInput").Call(Map(String()).Any().Values(Dict{Lit("path"): Id("path")})), - ), - Id("args").Op(":=").Id("append").Call(Id("base"), Id("opts").Op("...")), - Return(Id("StartToolCall").Call(Id("id"), Id("title"), Id("args").Op("..."))), - ) - f.Line() - // StartEditToolCall: sets kind=edit, status=pending, locations=[{path}], rawInput={path, content} - f.Comment("StartEditToolCall constructs a 'tool_call' update for editing content: kind=edit, status=pending, locations=[{path}], rawInput={path, content}.") - f.Func().Id("StartEditToolCall").Params( - Id("id").Id("ToolCallId"), - Id("title").String(), - Id("path").String(), - Id("content").Any(), - Id("opts").Op("...").Id("ToolCallStartOpt"), - ).Id("SessionUpdate").Block( - Id("base").Op(":=").Index().Id("ToolCallStartOpt").Values( - Id("WithStartKind").Call(Id("ToolKindEdit")), - Id("WithStartStatus").Call(Id("ToolCallStatusPending")), - Id("WithStartLocations").Call( - Index().Id("ToolCallLocation").Values( - Id("ToolCallLocation").Values(Dict{Id("Path"): Id("path")}), - ), - ), - Id("WithStartRawInput").Call(Map(String()).Any().Values(Dict{Lit("path"): Id("path"), Lit("content"): Id("content")})), - ), - Id("args").Op(":=").Id("append").Call(Id("base"), Id("opts").Op("...")), - Return(Id("StartToolCall").Call(Id("id"), Id("title"), Id("args").Op("..."))), - ) - f.Line() - var buf bytes.Buffer if err := f.Render(&buf); err != nil { return err } return os.WriteFile(filepath.Join(outDir, "helpers_gen.go"), buf.Bytes(), 0o644) } - -// Note: isStringConstUnion exists in types emitter; we reference that file-level function diff --git a/go/helpers.go b/go/helpers.go new file mode 100644 index 0000000..cedfea2 --- /dev/null +++ b/go/helpers.go @@ -0,0 +1,259 @@ +package acp + +// TextBlock constructs a text content block. +func TextBlock(text string) ContentBlock { + return ContentBlock{Text: &ContentBlockText{ + Text: text, + Type: "text", + }} +} + +// ImageBlock constructs an inline image content block with base64-encoded data. +func ImageBlock(data string, mimeType string) ContentBlock { + return ContentBlock{Image: &ContentBlockImage{ + Data: data, + MimeType: mimeType, + Type: "image", + }} +} + +// AudioBlock constructs an inline audio content block with base64-encoded data. +func AudioBlock(data string, mimeType string) ContentBlock { + return ContentBlock{Audio: &ContentBlockAudio{ + Data: data, + MimeType: mimeType, + Type: "audio", + }} +} + +// ResourceLinkBlock constructs a resource_link content block with a name and URI. +func ResourceLinkBlock(name string, uri string) ContentBlock { + return ContentBlock{ResourceLink: &ContentBlockResourceLink{ + Name: name, + Type: "resource_link", + Uri: uri, + }} +} + +// ResourceBlock wraps an embedded resource as a content block. +func ResourceBlock(res EmbeddedResource) ContentBlock { + return ContentBlock{Resource: &ContentBlockResource{ + Resource: res.Resource, + Type: "resource", + }} +} + +// ToolContent wraps a content block as tool-call content. +func ToolContent(block ContentBlock) ToolCallContent { + return ToolCallContent{Content: &ToolCallContentContent{ + Content: block, + Type: "content", + }} +} + +// ToolDiffContent constructs a diff tool-call content. If oldText is omitted, the field is left empty. +func ToolDiffContent(path string, newText string, oldText ...string) ToolCallContent { + var o *string + if len(oldText) > 0 { + o = &oldText[0] + } + return ToolCallContent{Diff: &ToolCallContentDiff{ + NewText: newText, + OldText: o, + Path: path, + Type: "diff", + }} +} + +// ToolTerminalRef constructs a terminal reference tool-call content. +func ToolTerminalRef(terminalID string) ToolCallContent { + return ToolCallContent{Terminal: &ToolCallContentTerminal{ + TerminalId: terminalID, + Type: "terminal", + }} +} + +// Ptr returns a pointer to v. +func Ptr[T any](v T) *T { + return &v +} + +// UpdateUserMessage constructs a user_message_chunk update with the given content. +func UpdateUserMessage(content ContentBlock) SessionUpdate { + return SessionUpdate{UserMessageChunk: &SessionUpdateUserMessageChunk{Content: content}} +} + +// UpdateUserMessageText constructs a user_message_chunk update from text. +func UpdateUserMessageText(text string) SessionUpdate { + return UpdateUserMessage(TextBlock(text)) +} + +// UpdateAgentMessage constructs an agent_message_chunk update with the given content. +func UpdateAgentMessage(content ContentBlock) SessionUpdate { + return SessionUpdate{AgentMessageChunk: &SessionUpdateAgentMessageChunk{Content: content}} +} + +// UpdateAgentMessageText constructs an agent_message_chunk update from text. +func UpdateAgentMessageText(text string) SessionUpdate { + return UpdateAgentMessage(TextBlock(text)) +} + +// UpdateAgentThought constructs an agent_thought_chunk update with the given content. +func UpdateAgentThought(content ContentBlock) SessionUpdate { + return SessionUpdate{AgentThoughtChunk: &SessionUpdateAgentThoughtChunk{Content: content}} +} + +// UpdateAgentThoughtText constructs an agent_thought_chunk update from text. +func UpdateAgentThoughtText(text string) SessionUpdate { + return UpdateAgentThought(TextBlock(text)) +} + +// UpdatePlan constructs a plan update with the provided entries. +func UpdatePlan(entries ...PlanEntry) SessionUpdate { + return SessionUpdate{Plan: &SessionUpdatePlan{Entries: entries}} +} + +type ToolCallStartOpt func(tc *SessionUpdateToolCall) + +// StartToolCall constructs a tool_call update with required fields and applies optional modifiers. +func StartToolCall(id ToolCallId, title string, opts ...ToolCallStartOpt) SessionUpdate { + tc := SessionUpdateToolCall{ + Title: title, + ToolCallId: id, + } + for _, opt := range opts { + opt(&tc) + } + return SessionUpdate{ToolCall: &tc} +} + +// WithStartKind sets the kind for a tool_call start update. +func WithStartKind(k ToolKind) ToolCallStartOpt { + return func(tc *SessionUpdateToolCall) { + tc.Kind = k + } +} + +// WithStartStatus sets the status for a tool_call start update. +func WithStartStatus(s ToolCallStatus) ToolCallStartOpt { + return func(tc *SessionUpdateToolCall) { + tc.Status = s + } +} + +// WithStartContent sets the initial content for a tool_call start update. +func WithStartContent(c []ToolCallContent) ToolCallStartOpt { + return func(tc *SessionUpdateToolCall) { + tc.Content = c + } +} + +// WithStartLocations sets file locations and, if a single path is provided and rawInput is empty, mirrors it as rawInput.path. +func WithStartLocations(l []ToolCallLocation) ToolCallStartOpt { + return func(tc *SessionUpdateToolCall) { + tc.Locations = l + if len(l) == 1 && l[0].Path != "" { + if tc.RawInput == nil { + tc.RawInput = map[string]any{"path": l[0].Path} + } else { + m, ok := tc.RawInput.(map[string]any) + if ok { + if _, exists := m["path"]; !exists { + m["path"] = l[0].Path + } + } + } + } + } +} + +// WithStartRawInput sets rawInput for a tool_call start update. +func WithStartRawInput(v any) ToolCallStartOpt { + return func(tc *SessionUpdateToolCall) { + tc.RawInput = v + } +} + +// WithStartRawOutput sets rawOutput for a tool_call start update. +func WithStartRawOutput(v any) ToolCallStartOpt { + return func(tc *SessionUpdateToolCall) { + tc.RawOutput = v + } +} + +type ToolCallUpdateOpt func(tu *SessionUpdateToolCallUpdate) + +// UpdateToolCall constructs a tool_call_update with the given ID and applies optional modifiers. +func UpdateToolCall(id ToolCallId, opts ...ToolCallUpdateOpt) SessionUpdate { + tu := SessionUpdateToolCallUpdate{ToolCallId: id} + for _, opt := range opts { + opt(&tu) + } + return SessionUpdate{ToolCallUpdate: &tu} +} + +// WithUpdateTitle sets the title for a tool_call_update. +func WithUpdateTitle(t string) ToolCallUpdateOpt { + return func(tu *SessionUpdateToolCallUpdate) { + tu.Title = Ptr(t) + } +} + +// WithUpdateKind sets the kind for a tool_call_update. +func WithUpdateKind(k ToolKind) ToolCallUpdateOpt { + return func(tu *SessionUpdateToolCallUpdate) { + tu.Kind = Ptr(k) + } +} + +// WithUpdateStatus sets the status for a tool_call_update. +func WithUpdateStatus(s ToolCallStatus) ToolCallUpdateOpt { + return func(tu *SessionUpdateToolCallUpdate) { + tu.Status = Ptr(s) + } +} + +// WithUpdateContent replaces the content collection for a tool_call_update. +func WithUpdateContent(c []ToolCallContent) ToolCallUpdateOpt { + return func(tu *SessionUpdateToolCallUpdate) { + tu.Content = c + } +} + +// WithUpdateLocations replaces the locations collection for a tool_call_update. +func WithUpdateLocations(l []ToolCallLocation) ToolCallUpdateOpt { + return func(tu *SessionUpdateToolCallUpdate) { + tu.Locations = l + } +} + +// WithUpdateRawInput sets rawInput for a tool_call_update. +func WithUpdateRawInput(v any) ToolCallUpdateOpt { + return func(tu *SessionUpdateToolCallUpdate) { + tu.RawInput = v + } +} + +// WithUpdateRawOutput sets rawOutput for a tool_call_update. +func WithUpdateRawOutput(v any) ToolCallUpdateOpt { + return func(tu *SessionUpdateToolCallUpdate) { + tu.RawOutput = v + } +} + +// StartReadToolCall constructs a 'tool_call' update for reading a file: kind=read, status=pending, locations=[{path}], rawInput={path}. +func StartReadToolCall(id ToolCallId, title string, path string, opts ...ToolCallStartOpt) SessionUpdate { + base := []ToolCallStartOpt{WithStartKind(ToolKindRead), WithStartStatus(ToolCallStatusPending), WithStartLocations([]ToolCallLocation{{Path: path}}), WithStartRawInput(map[string]any{"path": path})} + args := append(base, opts...) + return StartToolCall(id, title, args...) +} + +// StartEditToolCall constructs a 'tool_call' update for editing content: kind=edit, status=pending, locations=[{path}], rawInput={path, content}. +func StartEditToolCall(id ToolCallId, title string, path string, content any, opts ...ToolCallStartOpt) SessionUpdate { + base := []ToolCallStartOpt{WithStartKind(ToolKindEdit), WithStartStatus(ToolCallStatusPending), WithStartLocations([]ToolCallLocation{{Path: path}}), WithStartRawInput(map[string]any{ + "content": content, + "path": path, + })} + args := append(base, opts...) + return StartToolCall(id, title, args...) +} diff --git a/go/helpers_gen.go b/go/helpers_gen.go index c99f1c6..510e36a 100644 --- a/go/helpers_gen.go +++ b/go/helpers_gen.go @@ -2,291 +2,6 @@ package acp -// TextBlock constructs a text content block. -func TextBlock(text string) ContentBlock { - return ContentBlock{Text: &ContentBlockText{ - Text: text, - Type: "text", - }} -} - -// ImageBlock constructs an inline image content block with base64-encoded data. -func ImageBlock(data string, mimeType string) ContentBlock { - return ContentBlock{Image: &ContentBlockImage{ - Data: data, - MimeType: mimeType, - Type: "image", - }} -} - -// AudioBlock constructs an inline audio content block with base64-encoded data. -func AudioBlock(data string, mimeType string) ContentBlock { - return ContentBlock{Audio: &ContentBlockAudio{ - Data: data, - MimeType: mimeType, - Type: "audio", - }} -} - -// ResourceLinkBlock constructs a resource_link content block with a name and URI. -func ResourceLinkBlock(name string, uri string) ContentBlock { - return ContentBlock{ResourceLink: &ContentBlockResourceLink{ - Name: name, - Type: "resource_link", - Uri: uri, - }} -} - -// ResourceBlock wraps an embedded resource as a content block. -func ResourceBlock(res EmbeddedResource) ContentBlock { - var r EmbeddedResource = res - return ContentBlock{Resource: &ContentBlockResource{ - Resource: r.Resource, - Type: "resource", - }} -} - -// ToolContent wraps a content block as tool-call content. -func ToolContent(block ContentBlock) ToolCallContent { - return ToolCallContent{Content: &ToolCallContentContent{ - Content: block, - Type: "content", - }} -} - -// ToolDiffContent constructs a diff tool-call content. If oldText is omitted, the field is left empty. -func ToolDiffContent(path string, newText string, oldText ...string) ToolCallContent { - var o *string - if len(oldText) > 0 { - o = &oldText[0] - } - return ToolCallContent{Diff: &ToolCallContentDiff{ - NewText: newText, - OldText: o, - Path: path, - Type: "diff", - }} -} - -// ToolTerminalRef constructs a terminal reference tool-call content. -func ToolTerminalRef(terminalId string) ToolCallContent { - return ToolCallContent{Terminal: &ToolCallContentTerminal{ - TerminalId: terminalId, - Type: "terminal", - }} -} - -// Ptr returns a pointer to v. -func Ptr[T any](v T) *T { - return &v -} - -// UpdateUserMessage constructs a user_message_chunk update with the given content. -func UpdateUserMessage(content ContentBlock) SessionUpdate { - return SessionUpdate{UserMessageChunk: &SessionUpdateUserMessageChunk{Content: content}} -} - -// UpdateUserMessageText constructs a user_message_chunk update from text. -func UpdateUserMessageText(text string) SessionUpdate { - return UpdateUserMessage(TextBlock(text)) -} - -// UpdateAgentMessage constructs an agent_message_chunk update with the given content. -func UpdateAgentMessage(content ContentBlock) SessionUpdate { - return SessionUpdate{AgentMessageChunk: &SessionUpdateAgentMessageChunk{Content: content}} -} - -// UpdateAgentMessageText constructs an agent_message_chunk update from text. -func UpdateAgentMessageText(text string) SessionUpdate { - return UpdateAgentMessage(TextBlock(text)) -} - -// UpdateAgentThought constructs an agent_thought_chunk update with the given content. -func UpdateAgentThought(content ContentBlock) SessionUpdate { - return SessionUpdate{AgentThoughtChunk: &SessionUpdateAgentThoughtChunk{Content: content}} -} - -// UpdateAgentThoughtText constructs an agent_thought_chunk update from text. -func UpdateAgentThoughtText(text string) SessionUpdate { - return UpdateAgentThought(TextBlock(text)) -} - -// UpdatePlan constructs a plan update with the provided entries. -func UpdatePlan(entries ...PlanEntry) SessionUpdate { - return SessionUpdate{Plan: &SessionUpdatePlan{Entries: entries}} -} - -type ToolCallStartOpt func(tc *SessionUpdateToolCall) - -// StartToolCall constructs a tool_call update with required fields and applies optional modifiers. -func StartToolCall(id ToolCallId, title string, opts ...ToolCallStartOpt) SessionUpdate { - tc := SessionUpdateToolCall{ - Title: title, - ToolCallId: id, - } - for _, opt := range opts { - opt(&tc) - } - return SessionUpdate{ToolCall: &tc} -} - -// WithStartKind sets the kind for a tool_call start update. -func WithStartKind(k ToolKind) ToolCallStartOpt { - return func(tc *SessionUpdateToolCall) { - tc.Kind = k - } -} - -// WithStartStatus sets the status for a tool_call start update. -func WithStartStatus(s ToolCallStatus) ToolCallStartOpt { - return func(tc *SessionUpdateToolCall) { - tc.Status = s - } -} - -// WithStartContent sets the initial content for a tool_call start update. -func WithStartContent(c []ToolCallContent) ToolCallStartOpt { - return func(tc *SessionUpdateToolCall) { - tc.Content = c - } -} - -// WithStartLocations sets file locations and, if a single path is provided and rawInput is empty, mirrors it as rawInput.path. -func WithStartLocations(l []ToolCallLocation) ToolCallStartOpt { - return func(tc *SessionUpdateToolCall) { - tc.Locations = l - if len(l) == 1 && l[0].Path != "" { - if tc.RawInput == nil { - tc.RawInput = map[string]any{"path": l[0].Path} - } else { - m, ok := tc.RawInput.(map[string]any) - if ok { - if _, exists := m["path"]; !exists { - m["path"] = l[0].Path - } - } - } - } - } -} - -// WithStartRawInput sets rawInput for a tool_call start update. -func WithStartRawInput(v any) ToolCallStartOpt { - return func(tc *SessionUpdateToolCall) { - tc.RawInput = v - } -} - -// WithStartRawOutput sets rawOutput for a tool_call start update. -func WithStartRawOutput(v any) ToolCallStartOpt { - return func(tc *SessionUpdateToolCall) { - tc.RawOutput = v - } -} - -type ToolCallUpdateOpt func(tu *SessionUpdateToolCallUpdate) - -// UpdateToolCall constructs a tool_call_update with the given ID and applies optional modifiers. -func UpdateToolCall(id ToolCallId, opts ...ToolCallUpdateOpt) SessionUpdate { - tu := SessionUpdateToolCallUpdate{ToolCallId: id} - for _, opt := range opts { - opt(&tu) - } - return SessionUpdate{ToolCallUpdate: &tu} -} - -// WithUpdateTitle sets the title for a tool_call_update. -func WithUpdateTitle(t string) ToolCallUpdateOpt { - return func(tu *SessionUpdateToolCallUpdate) { - tu.Title = Ptr(t) - } -} - -// WithUpdateKind sets the kind for a tool_call_update. -func WithUpdateKind(k ToolKind) ToolCallUpdateOpt { - return func(tu *SessionUpdateToolCallUpdate) { - tu.Kind = Ptr(k) - } -} - -// WithUpdateStatus sets the status for a tool_call_update. -func WithUpdateStatus(s ToolCallStatus) ToolCallUpdateOpt { - return func(tu *SessionUpdateToolCallUpdate) { - tu.Status = Ptr(s) - } -} - -// WithUpdateContent replaces the content collection for a tool_call_update. -func WithUpdateContent(c []ToolCallContent) ToolCallUpdateOpt { - return func(tu *SessionUpdateToolCallUpdate) { - tu.Content = c - } -} - -// WithUpdateLocations replaces the locations collection for a tool_call_update. -func WithUpdateLocations(l []ToolCallLocation) ToolCallUpdateOpt { - return func(tu *SessionUpdateToolCallUpdate) { - tu.Locations = l - } -} - -// WithUpdateRawInput sets rawInput for a tool_call_update. -func WithUpdateRawInput(v any) ToolCallUpdateOpt { - return func(tu *SessionUpdateToolCallUpdate) { - tu.RawInput = v - } -} - -// WithUpdateRawOutput sets rawOutput for a tool_call_update. -func WithUpdateRawOutput(v any) ToolCallUpdateOpt { - return func(tu *SessionUpdateToolCallUpdate) { - tu.RawOutput = v - } -} - -// NewContentBlockText constructs a ContentBlock using the 'text' variant. -func NewContentBlockText(text string) ContentBlock { - return ContentBlock{Text: &ContentBlockText{ - Text: text, - Type: "text", - }} -} - -// NewContentBlockImage constructs a ContentBlock using the 'image' variant. -func NewContentBlockImage(data string, mimeType string) ContentBlock { - return ContentBlock{Image: &ContentBlockImage{ - Data: data, - MimeType: mimeType, - Type: "image", - }} -} - -// NewContentBlockAudio constructs a ContentBlock using the 'audio' variant. -func NewContentBlockAudio(data string, mimeType string) ContentBlock { - return ContentBlock{Audio: &ContentBlockAudio{ - Data: data, - MimeType: mimeType, - Type: "audio", - }} -} - -// NewContentBlockResourceLink constructs a ContentBlock using the 'resource_link' variant. -func NewContentBlockResourceLink(name string, uri string) ContentBlock { - return ContentBlock{ResourceLink: &ContentBlockResourceLink{ - Name: name, - Type: "resource_link", - Uri: uri, - }} -} - -// NewContentBlockResource constructs a ContentBlock using the 'resource' variant. -func NewContentBlockResource(resource EmbeddedResourceResource) ContentBlock { - return ContentBlock{Resource: &ContentBlockResource{ - Resource: resource, - Type: "resource", - }} -} - // NewRequestPermissionOutcomeCancelled constructs a RequestPermissionOutcome using the 'cancelled' variant. func NewRequestPermissionOutcomeCancelled() RequestPermissionOutcome { return RequestPermissionOutcome{Cancelled: &RequestPermissionOutcomeCancelled{Outcome: "cancelled"}} @@ -299,94 +14,3 @@ func NewRequestPermissionOutcomeSelected(optionId PermissionOptionId) RequestPer Outcome: "selected", }} } - -// NewSessionUpdateUserMessageChunk constructs a SessionUpdate using the 'user_message_chunk' variant. -func NewSessionUpdateUserMessageChunk(content ContentBlock) SessionUpdate { - return SessionUpdate{UserMessageChunk: &SessionUpdateUserMessageChunk{ - Content: content, - SessionUpdate: "user_message_chunk", - }} -} - -// NewSessionUpdateAgentMessageChunk constructs a SessionUpdate using the 'agent_message_chunk' variant. -func NewSessionUpdateAgentMessageChunk(content ContentBlock) SessionUpdate { - return SessionUpdate{AgentMessageChunk: &SessionUpdateAgentMessageChunk{ - Content: content, - SessionUpdate: "agent_message_chunk", - }} -} - -// NewSessionUpdateAgentThoughtChunk constructs a SessionUpdate using the 'agent_thought_chunk' variant. -func NewSessionUpdateAgentThoughtChunk(content ContentBlock) SessionUpdate { - return SessionUpdate{AgentThoughtChunk: &SessionUpdateAgentThoughtChunk{ - Content: content, - SessionUpdate: "agent_thought_chunk", - }} -} - -// NewSessionUpdateToolCall constructs a SessionUpdate using the 'tool_call' variant. -func NewSessionUpdateToolCall(toolCallId ToolCallId, title string) SessionUpdate { - return SessionUpdate{ToolCall: &SessionUpdateToolCall{ - SessionUpdate: "tool_call", - Title: title, - ToolCallId: toolCallId, - }} -} - -// NewSessionUpdateToolCallUpdate constructs a SessionUpdate using the 'tool_call_update' variant. -func NewSessionUpdateToolCallUpdate(toolCallId ToolCallId) SessionUpdate { - return SessionUpdate{ToolCallUpdate: &SessionUpdateToolCallUpdate{ - SessionUpdate: "tool_call_update", - ToolCallId: toolCallId, - }} -} - -// NewSessionUpdatePlan constructs a SessionUpdate using the 'plan' variant. -func NewSessionUpdatePlan(entries []PlanEntry) SessionUpdate { - return SessionUpdate{Plan: &SessionUpdatePlan{ - Entries: entries, - SessionUpdate: "plan", - }} -} - -// NewToolCallContentContent constructs a ToolCallContent using the 'content' variant. -func NewToolCallContentContent(content ContentBlock) ToolCallContent { - return ToolCallContent{Content: &ToolCallContentContent{ - Content: content, - Type: "content", - }} -} - -// NewToolCallContentDiff constructs a ToolCallContent using the 'diff' variant. -func NewToolCallContentDiff(path string, newText string) ToolCallContent { - return ToolCallContent{Diff: &ToolCallContentDiff{ - NewText: newText, - Path: path, - Type: "diff", - }} -} - -// NewToolCallContentTerminal constructs a ToolCallContent using the 'terminal' variant. -func NewToolCallContentTerminal(terminalId string) ToolCallContent { - return ToolCallContent{Terminal: &ToolCallContentTerminal{ - TerminalId: terminalId, - Type: "terminal", - }} -} - -// StartReadToolCall constructs a 'tool_call' update for reading a file: kind=read, status=pending, locations=[{path}], rawInput={path}. -func StartReadToolCall(id ToolCallId, title string, path string, opts ...ToolCallStartOpt) SessionUpdate { - base := []ToolCallStartOpt{WithStartKind(ToolKindRead), WithStartStatus(ToolCallStatusPending), WithStartLocations([]ToolCallLocation{ToolCallLocation{Path: path}}), WithStartRawInput(map[string]any{"path": path})} - args := append(base, opts...) - return StartToolCall(id, title, args...) -} - -// StartEditToolCall constructs a 'tool_call' update for editing content: kind=edit, status=pending, locations=[{path}], rawInput={path, content}. -func StartEditToolCall(id ToolCallId, title string, path string, content any, opts ...ToolCallStartOpt) SessionUpdate { - base := []ToolCallStartOpt{WithStartKind(ToolKindEdit), WithStartStatus(ToolCallStatusPending), WithStartLocations([]ToolCallLocation{ToolCallLocation{Path: path}}), WithStartRawInput(map[string]any{ - "content": content, - "path": path, - })} - args := append(base, opts...) - return StartToolCall(id, title, args...) -} diff --git a/go/json_parity_test.go b/go/json_parity_test.go index b2d4590..ed1f620 100644 --- a/go/json_parity_test.go +++ b/go/json_parity_test.go @@ -73,37 +73,24 @@ func runGolden[T any](builds ...func() T) func(t *testing.T) { func TestJSONGolden_ContentBlocks(t *testing.T) { t.Run("content_text", runGolden( func() ContentBlock { return TextBlock("What's the weather like today?") }, - func() ContentBlock { return NewContentBlockText("What's the weather like today?") }, )) t.Run("content_image", runGolden( func() ContentBlock { return ImageBlock("iVBORw0KGgoAAAANSUhEUgAAAAEAAAAB...", "image/png") }, - func() ContentBlock { return NewContentBlockImage("iVBORw0KGgoAAAANSUhEUgAAAAEAAAAB...", "image/png") }, )) t.Run("content_audio", runGolden( func() ContentBlock { return AudioBlock("UklGRiQAAABXQVZFZm10IBAAAAABAAEAQB8AAAB...", "audio/wav") }, - func() ContentBlock { - return NewContentBlockAudio("UklGRiQAAABXQVZFZm10IBAAAAABAAEAQB8AAAB...", "audio/wav") - }, )) t.Run("content_resource_text", runGolden( func() ContentBlock { res := EmbeddedResourceResource{TextResourceContents: &TextResourceContents{Uri: "file:///home/user/script.py", MimeType: Ptr("text/x-python"), Text: "def hello():\n print('Hello, world!')"}} return ResourceBlock(EmbeddedResource{Resource: res}) }, - func() ContentBlock { - res := EmbeddedResourceResource{TextResourceContents: &TextResourceContents{Uri: "file:///home/user/script.py", MimeType: Ptr("text/x-python"), Text: "def hello():\n print('Hello, world!')"}} - return NewContentBlockResource(res) - }, )) t.Run("content_resource_blob", runGolden( func() ContentBlock { res := EmbeddedResourceResource{BlobResourceContents: &BlobResourceContents{Uri: "file:///home/user/document.pdf", MimeType: Ptr("application/pdf"), Blob: ""}} return ResourceBlock(EmbeddedResource{Resource: res}) }, - func() ContentBlock { - res := EmbeddedResourceResource{BlobResourceContents: &BlobResourceContents{Uri: "file:///home/user/document.pdf", MimeType: Ptr("application/pdf"), Blob: ""}} - return NewContentBlockResource(res) - }, )) t.Run("content_resource_link", runGolden( func() ContentBlock { @@ -119,23 +106,12 @@ func TestJSONGolden_ContentBlocks(t *testing.T) { cb.ResourceLink.Size = &sz return cb }, - func() ContentBlock { - cb := NewContentBlockResourceLink("document.pdf", "file:///home/user/document.pdf") - mt := "application/pdf" - sz := 1024000 - cb.ResourceLink.MimeType = &mt - cb.ResourceLink.Size = &sz - return cb - }, )) } func TestJSONGolden_ToolCallContent(t *testing.T) { t.Run("tool_content_content_text", runGolden( func() ToolCallContent { return ToolContent(TextBlock("Analysis complete. Found 3 issues.")) }, - func() ToolCallContent { - return NewToolCallContentContent(TextBlock("Analysis complete. Found 3 issues.")) - }, )) t.Run("tool_content_diff", runGolden(func() ToolCallContent { old := "{\n \"debug\": false\n}" @@ -145,13 +121,9 @@ func TestJSONGolden_ToolCallContent(t *testing.T) { func() ToolCallContent { return ToolDiffContent("/home/user/project/src/config.json", "{\n \"debug\": true\n}") }, - func() ToolCallContent { - return NewToolCallContentDiff("/home/user/project/src/config.json", "{\n \"debug\": true\n}") - }, )) t.Run("tool_content_terminal", runGolden( func() ToolCallContent { return ToolTerminalRef("term_001") }, - func() ToolCallContent { return NewToolCallContentTerminal("term_001") }, )) } @@ -178,27 +150,18 @@ func TestJSONGolden_SessionUpdates(t *testing.T) { return SessionUpdate{UserMessageChunk: &SessionUpdateUserMessageChunk{Content: TextBlock("What's the capital of France?")}} }, func() SessionUpdate { return UpdateUserMessageText("What's the capital of France?") }, - func() SessionUpdate { - return NewSessionUpdateUserMessageChunk(TextBlock("What's the capital of France?")) - }, )) t.Run("session_update_agent_message_chunk", runGolden( func() SessionUpdate { return SessionUpdate{AgentMessageChunk: &SessionUpdateAgentMessageChunk{Content: TextBlock("The capital of France is Paris.")}} }, func() SessionUpdate { return UpdateAgentMessageText("The capital of France is Paris.") }, - func() SessionUpdate { - return NewSessionUpdateAgentMessageChunk(TextBlock("The capital of France is Paris.")) - }, )) t.Run("session_update_agent_thought_chunk", runGolden( func() SessionUpdate { return SessionUpdate{AgentThoughtChunk: &SessionUpdateAgentThoughtChunk{Content: TextBlock("Thinking about best approach...")}} }, func() SessionUpdate { return UpdateAgentThoughtText("Thinking about best approach...") }, - func() SessionUpdate { - return NewSessionUpdateAgentThoughtChunk(TextBlock("Thinking about best approach...")) - }, )) t.Run("session_update_plan", runGolden( func() SessionUpdate { @@ -210,9 +173,6 @@ func TestJSONGolden_SessionUpdates(t *testing.T) { PlanEntry{Content: "Identify potential type issues", Priority: PlanEntryPriorityMedium, Status: PlanEntryStatusPending}, ) }, - func() SessionUpdate { - return NewSessionUpdatePlan([]PlanEntry{{Content: "Check for syntax errors", Priority: PlanEntryPriorityHigh, Status: PlanEntryStatusPending}, {Content: "Identify potential type issues", Priority: PlanEntryPriorityMedium, Status: PlanEntryStatusPending}}) - }, )) t.Run("session_update_tool_call", runGolden( func() SessionUpdate { @@ -222,9 +182,7 @@ func TestJSONGolden_SessionUpdates(t *testing.T) { return StartToolCall("call_001", "Reading configuration file", WithStartKind(ToolKindRead), WithStartStatus(ToolCallStatusPending)) }, )) - t.Run("session_update_tool_call_minimal", runGolden( - func() SessionUpdate { return NewSessionUpdateToolCall("call_001", "Reading configuration file") }, - )) + // Removed: session_update_tool_call_minimal (deprecated New helper) t.Run("session_update_tool_call_read", runGolden( func() SessionUpdate { return StartReadToolCall("call_001", "Reading configuration file", "/home/user/project/src/config.json") @@ -248,9 +206,7 @@ func TestJSONGolden_SessionUpdates(t *testing.T) { return UpdateToolCall("call_001", WithUpdateStatus(ToolCallStatusInProgress), WithUpdateContent([]ToolCallContent{ToolContent(TextBlock("Found 3 configuration files..."))})) }, )) - t.Run("session_update_tool_call_update_minimal", runGolden( - func() SessionUpdate { return NewSessionUpdateToolCallUpdate("call_001") }, - )) + // Removed: session_update_tool_call_update_minimal (deprecated New helper) t.Run("session_update_tool_call_update_more_fields", runGolden( func() SessionUpdate { return UpdateToolCall( diff --git a/go/testdata/json_golden/session_update_tool_call_minimal.json b/go/testdata/json_golden/session_update_tool_call_minimal.json deleted file mode 100644 index af5edd7..0000000 --- a/go/testdata/json_golden/session_update_tool_call_minimal.json +++ /dev/null @@ -1,6 +0,0 @@ -{ - "sessionUpdate": "tool_call", - "toolCallId": "call_001", - "title": "Reading configuration file" -} - diff --git a/go/testdata/json_golden/session_update_tool_call_update_minimal.json b/go/testdata/json_golden/session_update_tool_call_update_minimal.json deleted file mode 100644 index 4493e55..0000000 --- a/go/testdata/json_golden/session_update_tool_call_update_minimal.json +++ /dev/null @@ -1,5 +0,0 @@ -{ - "sessionUpdate": "tool_call_update", - "toolCallId": "call_001" -} - From 43d82f008ab71457aa2a290c3e7bbcea7e25d746 Mon Sep 17 00:00:00 2001 From: Thomas Kosiewski Date: Tue, 2 Sep 2025 21:10:26 +0200 Subject: [PATCH 17/22] test: remove trailing newlines from JSON golden files Change-Id: I07e46aa00b16f50b44e8e56b614c125cbf8bcb47 Signed-off-by: Thomas Kosiewski --- go/testdata/json_golden/cancel_notification.json | 1 - go/testdata/json_golden/content_audio.json | 1 - go/testdata/json_golden/content_image.json | 1 - go/testdata/json_golden/content_resource_blob.json | 1 - go/testdata/json_golden/content_resource_link.json | 1 - go/testdata/json_golden/content_resource_text.json | 1 - go/testdata/json_golden/content_text.json | 1 - go/testdata/json_golden/fs_read_text_file_request.json | 1 - go/testdata/json_golden/fs_read_text_file_response.json | 1 - go/testdata/json_golden/fs_write_text_file_request.json | 1 - go/testdata/json_golden/initialize_request.json | 1 - go/testdata/json_golden/initialize_response.json | 1 - go/testdata/json_golden/new_session_request.json | 1 - go/testdata/json_golden/new_session_response.json | 1 - go/testdata/json_golden/permission_outcome_cancelled.json | 1 - go/testdata/json_golden/permission_outcome_selected.json | 1 - go/testdata/json_golden/prompt_request.json | 1 - go/testdata/json_golden/request_permission_request.json | 1 - .../json_golden/request_permission_response_selected.json | 1 - go/testdata/json_golden/session_update_agent_message_chunk.json | 1 - go/testdata/json_golden/session_update_agent_thought_chunk.json | 1 - go/testdata/json_golden/session_update_plan.json | 1 - go/testdata/json_golden/session_update_tool_call.json | 1 - go/testdata/json_golden/session_update_tool_call_edit.json | 1 - .../json_golden/session_update_tool_call_locations_rawinput.json | 1 - go/testdata/json_golden/session_update_tool_call_read.json | 1 - .../json_golden/session_update_tool_call_update_content.json | 1 - .../json_golden/session_update_tool_call_update_more_fields.json | 1 - go/testdata/json_golden/session_update_user_message_chunk.json | 1 - go/testdata/json_golden/tool_content_content_text.json | 1 - go/testdata/json_golden/tool_content_diff.json | 1 - go/testdata/json_golden/tool_content_diff_no_old.json | 1 - go/testdata/json_golden/tool_content_terminal.json | 1 - 33 files changed, 33 deletions(-) diff --git a/go/testdata/json_golden/cancel_notification.json b/go/testdata/json_golden/cancel_notification.json index f1b3079..a5461d2 100644 --- a/go/testdata/json_golden/cancel_notification.json +++ b/go/testdata/json_golden/cancel_notification.json @@ -1,4 +1,3 @@ { "sessionId": "sess_abc123def456" } - diff --git a/go/testdata/json_golden/content_audio.json b/go/testdata/json_golden/content_audio.json index 6474c1d..6cd650e 100644 --- a/go/testdata/json_golden/content_audio.json +++ b/go/testdata/json_golden/content_audio.json @@ -3,4 +3,3 @@ "mimeType": "audio/wav", "data": "UklGRiQAAABXQVZFZm10IBAAAAABAAEAQB8AAAB..." } - diff --git a/go/testdata/json_golden/content_image.json b/go/testdata/json_golden/content_image.json index 87fd55c..fca8b88 100644 --- a/go/testdata/json_golden/content_image.json +++ b/go/testdata/json_golden/content_image.json @@ -3,4 +3,3 @@ "mimeType": "image/png", "data": "iVBORw0KGgoAAAANSUhEUgAAAAEAAAAB..." } - diff --git a/go/testdata/json_golden/content_resource_blob.json b/go/testdata/json_golden/content_resource_blob.json index 121eda3..4832503 100644 --- a/go/testdata/json_golden/content_resource_blob.json +++ b/go/testdata/json_golden/content_resource_blob.json @@ -6,4 +6,3 @@ "blob": "" } } - diff --git a/go/testdata/json_golden/content_resource_link.json b/go/testdata/json_golden/content_resource_link.json index 82a0a30..4e33c1e 100644 --- a/go/testdata/json_golden/content_resource_link.json +++ b/go/testdata/json_golden/content_resource_link.json @@ -5,4 +5,3 @@ "mimeType": "application/pdf", "size": 1024000 } - diff --git a/go/testdata/json_golden/content_resource_text.json b/go/testdata/json_golden/content_resource_text.json index 80f40fa..f73945a 100644 --- a/go/testdata/json_golden/content_resource_text.json +++ b/go/testdata/json_golden/content_resource_text.json @@ -6,4 +6,3 @@ "text": "def hello():\n print('Hello, world!')" } } - diff --git a/go/testdata/json_golden/content_text.json b/go/testdata/json_golden/content_text.json index eca6047..63b2e85 100644 --- a/go/testdata/json_golden/content_text.json +++ b/go/testdata/json_golden/content_text.json @@ -2,4 +2,3 @@ "type": "text", "text": "What's the weather like today?" } - diff --git a/go/testdata/json_golden/fs_read_text_file_request.json b/go/testdata/json_golden/fs_read_text_file_request.json index ec25ffa..3d3ccca 100644 --- a/go/testdata/json_golden/fs_read_text_file_request.json +++ b/go/testdata/json_golden/fs_read_text_file_request.json @@ -4,4 +4,3 @@ "line": 10, "limit": 50 } - diff --git a/go/testdata/json_golden/fs_read_text_file_response.json b/go/testdata/json_golden/fs_read_text_file_response.json index 68c73a3..b5dac57 100644 --- a/go/testdata/json_golden/fs_read_text_file_response.json +++ b/go/testdata/json_golden/fs_read_text_file_response.json @@ -1,4 +1,3 @@ { "content": "def hello_world():\n print('Hello, world!')\n" } - diff --git a/go/testdata/json_golden/fs_write_text_file_request.json b/go/testdata/json_golden/fs_write_text_file_request.json index eb72e39..efbad09 100644 --- a/go/testdata/json_golden/fs_write_text_file_request.json +++ b/go/testdata/json_golden/fs_write_text_file_request.json @@ -3,4 +3,3 @@ "path": "/home/user/project/config.json", "content": "{\n \"debug\": true,\n \"version\": \"1.0.0\"\n}" } - diff --git a/go/testdata/json_golden/initialize_request.json b/go/testdata/json_golden/initialize_request.json index 5f689c7..b239909 100644 --- a/go/testdata/json_golden/initialize_request.json +++ b/go/testdata/json_golden/initialize_request.json @@ -7,4 +7,3 @@ } } } - diff --git a/go/testdata/json_golden/initialize_response.json b/go/testdata/json_golden/initialize_response.json index 499e8fa..6524b96 100644 --- a/go/testdata/json_golden/initialize_response.json +++ b/go/testdata/json_golden/initialize_response.json @@ -10,4 +10,3 @@ }, "authMethods": [] } - diff --git a/go/testdata/json_golden/new_session_request.json b/go/testdata/json_golden/new_session_request.json index 271df0f..27f57c2 100644 --- a/go/testdata/json_golden/new_session_request.json +++ b/go/testdata/json_golden/new_session_request.json @@ -9,4 +9,3 @@ } ] } - diff --git a/go/testdata/json_golden/new_session_response.json b/go/testdata/json_golden/new_session_response.json index f1b3079..a5461d2 100644 --- a/go/testdata/json_golden/new_session_response.json +++ b/go/testdata/json_golden/new_session_response.json @@ -1,4 +1,3 @@ { "sessionId": "sess_abc123def456" } - diff --git a/go/testdata/json_golden/permission_outcome_cancelled.json b/go/testdata/json_golden/permission_outcome_cancelled.json index 3fe3090..38f0331 100644 --- a/go/testdata/json_golden/permission_outcome_cancelled.json +++ b/go/testdata/json_golden/permission_outcome_cancelled.json @@ -1,4 +1,3 @@ { "outcome": "cancelled" } - diff --git a/go/testdata/json_golden/permission_outcome_selected.json b/go/testdata/json_golden/permission_outcome_selected.json index 3c79f8f..3a194c2 100644 --- a/go/testdata/json_golden/permission_outcome_selected.json +++ b/go/testdata/json_golden/permission_outcome_selected.json @@ -2,4 +2,3 @@ "outcome": "selected", "optionId": "allow-once" } - diff --git a/go/testdata/json_golden/prompt_request.json b/go/testdata/json_golden/prompt_request.json index c4da8fb..816fae1 100644 --- a/go/testdata/json_golden/prompt_request.json +++ b/go/testdata/json_golden/prompt_request.json @@ -15,4 +15,3 @@ } ] } - diff --git a/go/testdata/json_golden/request_permission_request.json b/go/testdata/json_golden/request_permission_request.json index f845f99..1fb297f 100644 --- a/go/testdata/json_golden/request_permission_request.json +++ b/go/testdata/json_golden/request_permission_request.json @@ -16,4 +16,3 @@ } ] } - diff --git a/go/testdata/json_golden/request_permission_response_selected.json b/go/testdata/json_golden/request_permission_response_selected.json index 98df2fc..e29b89b 100644 --- a/go/testdata/json_golden/request_permission_response_selected.json +++ b/go/testdata/json_golden/request_permission_response_selected.json @@ -4,4 +4,3 @@ "optionId": "allow-once" } } - diff --git a/go/testdata/json_golden/session_update_agent_message_chunk.json b/go/testdata/json_golden/session_update_agent_message_chunk.json index 9b9f6d0..7ace7ed 100644 --- a/go/testdata/json_golden/session_update_agent_message_chunk.json +++ b/go/testdata/json_golden/session_update_agent_message_chunk.json @@ -5,4 +5,3 @@ "text": "The capital of France is Paris." } } - diff --git a/go/testdata/json_golden/session_update_agent_thought_chunk.json b/go/testdata/json_golden/session_update_agent_thought_chunk.json index b331119..893c13b 100644 --- a/go/testdata/json_golden/session_update_agent_thought_chunk.json +++ b/go/testdata/json_golden/session_update_agent_thought_chunk.json @@ -5,4 +5,3 @@ "text": "Thinking about best approach..." } } - diff --git a/go/testdata/json_golden/session_update_plan.json b/go/testdata/json_golden/session_update_plan.json index 744b7cd..bad3e8a 100644 --- a/go/testdata/json_golden/session_update_plan.json +++ b/go/testdata/json_golden/session_update_plan.json @@ -13,4 +13,3 @@ } ] } - diff --git a/go/testdata/json_golden/session_update_tool_call.json b/go/testdata/json_golden/session_update_tool_call.json index 0ad4ce4..448649d 100644 --- a/go/testdata/json_golden/session_update_tool_call.json +++ b/go/testdata/json_golden/session_update_tool_call.json @@ -5,4 +5,3 @@ "kind": "read", "status": "pending" } - diff --git a/go/testdata/json_golden/session_update_tool_call_edit.json b/go/testdata/json_golden/session_update_tool_call_edit.json index 6ddd93d..1cf0bda 100644 --- a/go/testdata/json_golden/session_update_tool_call_edit.json +++ b/go/testdata/json_golden/session_update_tool_call_edit.json @@ -14,4 +14,3 @@ "content": "print('hello')" } } - diff --git a/go/testdata/json_golden/session_update_tool_call_locations_rawinput.json b/go/testdata/json_golden/session_update_tool_call_locations_rawinput.json index d76507a..a1ac3e4 100644 --- a/go/testdata/json_golden/session_update_tool_call_locations_rawinput.json +++ b/go/testdata/json_golden/session_update_tool_call_locations_rawinput.json @@ -11,4 +11,3 @@ "path": "/home/user/project/src/config.json" } } - diff --git a/go/testdata/json_golden/session_update_tool_call_read.json b/go/testdata/json_golden/session_update_tool_call_read.json index bfc2008..d533afb 100644 --- a/go/testdata/json_golden/session_update_tool_call_read.json +++ b/go/testdata/json_golden/session_update_tool_call_read.json @@ -13,4 +13,3 @@ "path": "/home/user/project/src/config.json" } } - diff --git a/go/testdata/json_golden/session_update_tool_call_update_content.json b/go/testdata/json_golden/session_update_tool_call_update_content.json index 6db93a5..e28b461 100644 --- a/go/testdata/json_golden/session_update_tool_call_update_content.json +++ b/go/testdata/json_golden/session_update_tool_call_update_content.json @@ -12,4 +12,3 @@ } ] } - diff --git a/go/testdata/json_golden/session_update_tool_call_update_more_fields.json b/go/testdata/json_golden/session_update_tool_call_update_more_fields.json index 1469cae..d5af335 100644 --- a/go/testdata/json_golden/session_update_tool_call_update_more_fields.json +++ b/go/testdata/json_golden/session_update_tool_call_update_more_fields.json @@ -25,4 +25,3 @@ } ] } - diff --git a/go/testdata/json_golden/session_update_user_message_chunk.json b/go/testdata/json_golden/session_update_user_message_chunk.json index 7944b98..8ca73e7 100644 --- a/go/testdata/json_golden/session_update_user_message_chunk.json +++ b/go/testdata/json_golden/session_update_user_message_chunk.json @@ -5,4 +5,3 @@ "text": "What's the capital of France?" } } - diff --git a/go/testdata/json_golden/tool_content_content_text.json b/go/testdata/json_golden/tool_content_content_text.json index 820a413..bf3b6f7 100644 --- a/go/testdata/json_golden/tool_content_content_text.json +++ b/go/testdata/json_golden/tool_content_content_text.json @@ -5,4 +5,3 @@ "text": "Analysis complete. Found 3 issues." } } - diff --git a/go/testdata/json_golden/tool_content_diff.json b/go/testdata/json_golden/tool_content_diff.json index c1755dc..98482cb 100644 --- a/go/testdata/json_golden/tool_content_diff.json +++ b/go/testdata/json_golden/tool_content_diff.json @@ -4,4 +4,3 @@ "oldText": "{\n \"debug\": false\n}", "newText": "{\n \"debug\": true\n}" } - diff --git a/go/testdata/json_golden/tool_content_diff_no_old.json b/go/testdata/json_golden/tool_content_diff_no_old.json index e14cbe9..c044187 100644 --- a/go/testdata/json_golden/tool_content_diff_no_old.json +++ b/go/testdata/json_golden/tool_content_diff_no_old.json @@ -3,4 +3,3 @@ "path": "/home/user/project/src/config.json", "newText": "{\n \"debug\": true\n}" } - diff --git a/go/testdata/json_golden/tool_content_terminal.json b/go/testdata/json_golden/tool_content_terminal.json index 387b7d8..fd0c676 100644 --- a/go/testdata/json_golden/tool_content_terminal.json +++ b/go/testdata/json_golden/tool_content_terminal.json @@ -2,4 +2,3 @@ "type": "terminal", "terminalId": "term_001" } - From 913fa9c4f217159ec9119debe6b591d2f49041d4 Mon Sep 17 00:00:00 2001 From: Thomas Kosiewski Date: Wed, 10 Sep 2025 12:35:12 +0200 Subject: [PATCH 18/22] feat: add Nix development environment with JSON Schema defaults and Go module restructure Change-Id: Ic03480847e6d562067e0b170eb72fe6ea39efc20 Signed-off-by: Thomas Kosiewski --- .gitignore | 3 + flake.lock | 96 ++++++++++ flake.nix | 55 ++++++ go.mod | 3 + go/cmd/generate/internal/emit/jenwrap.go | 1 + go/cmd/generate/internal/emit/types.go | 144 +++++++++++++- go/cmd/generate/internal/load/load.go | 3 + go/defaults_test.go | 129 +++++++++++++ go/example/agent/main.go | 34 +++- go/example/client/main.go | 22 ++- go/go.mod | 3 - go/json_parity_test.go | 10 +- go/types_gen.go | 233 ++++++++++++++++++++++- package.json | 10 +- rust/markdown_generator.rs | 32 +++- 15 files changed, 748 insertions(+), 30 deletions(-) create mode 100644 flake.lock create mode 100644 flake.nix create mode 100644 go.mod create mode 100644 go/defaults_test.go delete mode 100644 go/go.mod diff --git a/.gitignore b/.gitignore index 2e7ea93..42c9116 100644 --- a/.gitignore +++ b/.gitignore @@ -17,3 +17,6 @@ typescript/docs/ # Go files .gocache .gopath + +.envrc +.direnv diff --git a/flake.lock b/flake.lock new file mode 100644 index 0000000..5fd14c9 --- /dev/null +++ b/flake.lock @@ -0,0 +1,96 @@ +{ + "nodes": { + "flake-utils": { + "inputs": { + "systems": "systems" + }, + "locked": { + "lastModified": 1731533236, + "narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=", + "owner": "numtide", + "repo": "flake-utils", + "rev": "11707dc2f618dd54ca8739b309ec4fc024de578b", + "type": "github" + }, + "original": { + "owner": "numtide", + "repo": "flake-utils", + "type": "github" + } + }, + "nixpkgs": { + "locked": { + "lastModified": 1757244434, + "narHash": "sha256-AeqTqY0Y95K1Fgs6wuT1LafBNcmKxcOkWnm4alD9pqM=", + "owner": "NixOS", + "repo": "nixpkgs", + "rev": "092c565d333be1e17b4779ac22104338941d913f", + "type": "github" + }, + "original": { + "owner": "NixOS", + "ref": "nixos-25.05", + "repo": "nixpkgs", + "type": "github" + } + }, + "nixpkgs_2": { + "locked": { + "lastModified": 1744536153, + "narHash": "sha256-awS2zRgF4uTwrOKwwiJcByDzDOdo3Q1rPZbiHQg/N38=", + "owner": "NixOS", + "repo": "nixpkgs", + "rev": "18dd725c29603f582cf1900e0d25f9f1063dbf11", + "type": "github" + }, + "original": { + "owner": "NixOS", + "ref": "nixpkgs-unstable", + "repo": "nixpkgs", + "type": "github" + } + }, + "root": { + "inputs": { + "flake-utils": "flake-utils", + "nixpkgs": "nixpkgs", + "rust-overlay": "rust-overlay" + } + }, + "rust-overlay": { + "inputs": { + "nixpkgs": "nixpkgs_2" + }, + "locked": { + "lastModified": 1757298987, + "narHash": "sha256-yuFSw6fpfjPtVMmym51ozHYpJQ7SzVOTkk7tUv2JA0U=", + "owner": "oxalica", + "repo": "rust-overlay", + "rev": "cfd63776bde44438ff2936f0c9194c79dd407a5f", + "type": "github" + }, + "original": { + "owner": "oxalica", + "repo": "rust-overlay", + "type": "github" + } + }, + "systems": { + "locked": { + "lastModified": 1681028828, + "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=", + "owner": "nix-systems", + "repo": "default", + "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e", + "type": "github" + }, + "original": { + "owner": "nix-systems", + "repo": "default", + "type": "github" + } + } + }, + "root": "root", + "version": 7 +} diff --git a/flake.nix b/flake.nix new file mode 100644 index 0000000..094b83c --- /dev/null +++ b/flake.nix @@ -0,0 +1,55 @@ +{ + description = "Devshell for ACP"; + + inputs = { + nixpkgs.url = "github:NixOS/nixpkgs/nixos-25.05"; + flake-utils.url = "github:numtide/flake-utils"; + rust-overlay.url = "github:oxalica/rust-overlay"; + }; + + outputs = + { + self, + nixpkgs, + flake-utils, + rust-overlay, + }: + flake-utils.lib.eachDefaultSystem ( + system: + let + pkgs = import nixpkgs { + inherit system; + overlays = [ rust-overlay.overlays.default ]; + }; + + formatter = pkgs.nixfmt-rfc-style; + + # Rust toolchain derived from rust-toolchain.toml + # Uses oxalica/rust-overlay to match the pinned channel/components. + rustToolchain = pkgs.rust-bin.fromRustupToolchainFile ./rust-toolchain.toml; + in + { + inherit formatter; + + devShells.default = pkgs.mkShell { + packages = with pkgs; [ + # Rust toolchain pinned via rust-toolchain.toml + rustToolchain + pkg-config + openssl + + # Node.js toolchain + nodejs_24 + + # Go toolchain + go_1_24 + + # Nix formatter + formatter + ]; + + RUST_BACKTRACE = "1"; + }; + } + ); +} diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..f23c3f4 --- /dev/null +++ b/go.mod @@ -0,0 +1,3 @@ +module github.com/zed-industries/agent-client-protocol + +go 1.21 diff --git a/go/cmd/generate/internal/emit/jenwrap.go b/go/cmd/generate/internal/emit/jenwrap.go index 6691a72..07e8085 100644 --- a/go/cmd/generate/internal/emit/jenwrap.go +++ b/go/cmd/generate/internal/emit/jenwrap.go @@ -14,6 +14,7 @@ var ( NewFile = jen.NewFile Id = jen.Id Lit = jen.Lit + Line = jen.Line Func = jen.Func For = jen.For Range = jen.Range diff --git a/go/cmd/generate/internal/emit/types.go b/go/cmd/generate/internal/emit/types.go index 39855c8..aa7a608 100644 --- a/go/cmd/generate/internal/emit/types.go +++ b/go/cmd/generate/internal/emit/types.go @@ -2,6 +2,7 @@ package emit import ( "bytes" + "encoding/json" "fmt" "os" "path/filepath" @@ -78,6 +79,24 @@ func WriteTypesJen(outDir string, schema *load.Schema, meta *load.Meta) error { pkeys = append(pkeys, pk) } sort.Strings(pkeys) + // Track fields with schema defaults for generic (de)serialization + type DefaultKind int + const ( + KindNone DefaultKind = iota + KindScalar + KindArray + KindObject + ) + type defaultProp struct { + fieldName string + propName string + defaultJSON string + kind DefaultKind + allowNull bool + nilable bool // whether zero-value is nil (slice/map) + } + defaults := []defaultProp{} + for _, pk := range pkeys { prop := def.Properties[pk] field := util.ToExportedField(pk) @@ -85,17 +104,97 @@ func WriteTypesJen(outDir string, schema *load.Schema, meta *load.Meta) error { st = append(st, Comment(util.SanitizeComment(prop.Description))) } tag := pk + // Detect defaults generically + var dp *defaultProp + if prop.Default != nil { + // Compute kind from default value + k := defaultKindOf(prop.Default) + // Whether field zero is nil (slice/map) for Marshal fill-in + nilable := ir.PrimaryType(prop) == "array" || (ir.PrimaryType(prop) == "object" && len(prop.Properties) == 0 && prop.Ref == "") + // Capture canonical JSON of default + defJSON := "null" + if b, err := json.Marshal(prop.Default); err == nil { + defJSON = string(b) + } + dp = &defaultProp{ + fieldName: field, + propName: pk, + defaultJSON: defJSON, + kind: DefaultKind(k), + allowNull: includesNull(prop), + nilable: nilable, + } + defaults = append(defaults, *dp) + } if _, ok := req[pk]; !ok { - // Default: omit if empty, except for specific always-present fields - // Ensure InitializeResponse.authMethods is always encoded (even when empty) - if name != "InitializeResponse" || pk != "authMethods" { + // Default: omit if empty for optional fields, unless schema specifies + // a default array/object (always present on wire). + if dp == nil || (dp.kind != KindArray && dp.kind != KindObject) { tag = pk + ",omitempty" } } + // Emit an additional comment line indicating the default, if any. + if dp != nil && dp.defaultJSON != "null" { + // Insert an empty comment line before default comment (visual separator) + if prop.Description != "" { + st = append(st, Comment("")) + } + st = append(st, Comment(util.SanitizeComment(fmt.Sprintf("Defaults to %s if unset.", dp.defaultJSON)))) + } st = append(st, Id(field).Add(jenTypeForOptional(prop)).Tag(map[string]string{"json": tag})) } f.Type().Id(name).Struct(st...) f.Line() + + // If the struct has any fields with schema defaults, synthesize MarshalJSON and UnmarshalJSON + if len(defaults) > 0 { + // MarshalJSON: coerce nil slices to empty slices before encoding + f.Func().Params(Id("v").Id(name)).Id("MarshalJSON").Params().Params(Index().Byte(), Error()).BlockFunc(func(g *Group) { + g.Type().Id("Alias").Id(name) + g.Var().Id("a").Id("Alias") + g.Id("a").Op("=").Id("Alias").Call(Id("v")) + for _, dp := range defaults { + // For array/map defaults: if zero is nil, fill with default JSON when nil + if dp.kind == KindArray || dp.kind == KindObject { + if dp.nilable { + g.If(Id("a").Dot(dp.fieldName).Op("==").Nil()).Block( + Qual("encoding/json", "Unmarshal").Call(Index().Byte().Parens(Lit(dp.defaultJSON)), Op("&").Id("a").Dot(dp.fieldName)), + ) + } + } + // For typed object defaults (non-nilable), we keep Option A: do not inject values on encode. + } + g.Return(Qual("encoding/json", "Marshal").Call(Id("a"))) + }) + f.Line() + + // UnmarshalJSON: apply defaults when field is missing or null (and schema doesn't include null) + f.Func().Params(Id("v").Op("*").Id(name)).Id("UnmarshalJSON").Params(Id("b").Index().Byte()).Error().BlockFunc(func(g *Group) { + g.Var().Id("m").Map(String()).Qual("encoding/json", "RawMessage") + g.If(List(Id("err")).Op(":=").Qual("encoding/json", "Unmarshal").Call(Id("b"), Op("&").Id("m")), Id("err").Op("!=").Nil()).Block(Return(Id("err"))) + g.Type().Id("Alias").Id(name) + g.Var().Id("a").Id("Alias") + g.If(List(Id("err")).Op(":=").Qual("encoding/json", "Unmarshal").Call(Id("b"), Op("&").Id("a")), Id("err").Op("!=").Nil()).Block(Return(Id("err"))) + for _, dp := range defaults { + g.BlockFunc(func(h *Group) { + h.List(Id("_rm"), Id("_ok")).Op(":=").Id("m").Index(Lit(dp.propName)) + // Apply default when missing, or when null and null is not allowed + if dp.allowNull { + h.If(Op("!").Id("_ok")).Block( + Qual("encoding/json", "Unmarshal").Call(Index().Byte().Parens(Lit(dp.defaultJSON)), Op("&").Id("a").Dot(dp.fieldName)), + ) + } else { + h.If(Op("!").Id("_ok").Op("||").Parens(Id("string").Call(Id("_rm")).Op("==").Lit("null"))).Block( + Qual("encoding/json", "Unmarshal").Call(Index().Byte().Parens(Lit(dp.defaultJSON)), Op("&").Id("a").Dot(dp.fieldName)), + ) + } + }) + } + g.Op("*").Id("v").Op("=").Id(name).Call(Id("a")) + g.Return(Nil()) + }) + f.Line() + } case ir.PrimaryType(def) == "string" || ir.PrimaryType(def) == "integer" || ir.PrimaryType(def) == "number" || ir.PrimaryType(def) == "boolean": f.Type().Id(name).Add(primitiveJenType(ir.PrimaryType(def))) f.Line() @@ -272,6 +371,45 @@ func primitiveJenType(t string) Code { } } +// defaultKindOf classifies the JSON Schema default value into a coarse kind. +func defaultKindOf(val any) int { + switch val.(type) { + case nil: + return 0 // KindNone + case []any: + return 2 // KindArray + case map[string]any: + return 3 // KindObject + case string, float64, bool: + return 1 // KindScalar + default: + // Fallback: classify by fmt string + s := fmt.Sprint(val) + if strings.HasPrefix(s, "[") { + return 2 + } + if strings.HasPrefix(s, "map[") || strings.HasPrefix(s, "{") { + return 3 + } + return 1 + } +} + +// includesNull reports whether the property's type union contains null. +func includesNull(d *load.Definition) bool { + if d == nil || d.Type == nil { + return false + } + if arr, ok := d.Type.([]any); ok { + for _, v := range arr { + if s, ok2 := v.(string); ok2 && s == "null" { + return true + } + } + } + return false +} + func jenTypeFor(d *load.Definition) Code { if d == nil { return Any() diff --git a/go/cmd/generate/internal/load/load.go b/go/cmd/generate/internal/load/load.go index e9a044f..ac7b33a 100644 --- a/go/cmd/generate/internal/load/load.go +++ b/go/cmd/generate/internal/load/load.go @@ -37,6 +37,9 @@ type Definition struct { Const any `json:"const"` XSide string `json:"x-side"` XMethod string `json:"x-method"` + // Default holds the JSON Schema default value, when present. + // Used by generators to synthesize defaulting behavior. + Default any `json:"default"` } // ReadMeta loads schema/meta.json. diff --git a/go/defaults_test.go b/go/defaults_test.go new file mode 100644 index 0000000..f26f752 --- /dev/null +++ b/go/defaults_test.go @@ -0,0 +1,129 @@ +package acp + +import ( + "encoding/json" + "testing" +) + +// Ensure InitializeResponse.authMethods encodes to [] when nil or empty, +// and decodes to [] when missing or null. +func TestInitializeResponse_AuthMethods_Defaults(t *testing.T) { + t.Parallel() + t.Run("marshal_nil_slice_encodes_empty_array", func(t *testing.T) { + t.Parallel() + resp := InitializeResponse{ProtocolVersion: 1} + b, err := json.Marshal(resp) + if err != nil { + t.Fatalf("marshal error: %v", err) + } + var m map[string]any + if err := json.Unmarshal(b, &m); err != nil { + t.Fatalf("roundtrip unmarshal error: %v", err) + } + v, ok := m["authMethods"] + if !ok { + t.Fatalf("authMethods missing in JSON: %s", string(b)) + } + arr, ok := v.([]any) + if !ok || len(arr) != 0 { + t.Fatalf("authMethods should be empty array; got: %#v (json=%s)", v, string(b)) + } + }) + + t.Run("marshal_empty_slice_encodes_empty_array", func(t *testing.T) { + t.Parallel() + resp := InitializeResponse{ProtocolVersion: 1, AuthMethods: []AuthMethod{}} + b, err := json.Marshal(resp) + if err != nil { + t.Fatalf("marshal error: %v", err) + } + var m map[string]any + if err := json.Unmarshal(b, &m); err != nil { + t.Fatalf("roundtrip unmarshal error: %v", err) + } + v, ok := m["authMethods"] + if !ok { + t.Fatalf("authMethods missing in JSON: %s", string(b)) + } + arr, ok := v.([]any) + if !ok || len(arr) != 0 { + t.Fatalf("authMethods should be empty array; got: %#v (json=%s)", v, string(b)) + } + }) + + t.Run("unmarshal_missing_sets_empty_array", func(t *testing.T) { + t.Parallel() + var resp InitializeResponse + if err := json.Unmarshal([]byte(`{"protocolVersion":1}`), &resp); err != nil { + t.Fatalf("unmarshal error: %v", err) + } + if resp.AuthMethods == nil || len(resp.AuthMethods) != 0 { + t.Fatalf("expected default empty authMethods; got: %#v", resp.AuthMethods) + } + }) + + t.Run("unmarshal_null_sets_empty_array", func(t *testing.T) { + t.Parallel() + var resp InitializeResponse + if err := json.Unmarshal([]byte(`{"protocolVersion":1, "authMethods": null}`), &resp); err != nil { + t.Fatalf("unmarshal error: %v", err) + } + if resp.AuthMethods == nil || len(resp.AuthMethods) != 0 { + t.Fatalf("expected default empty authMethods on null; got: %#v", resp.AuthMethods) + } + }) +} + +// Ensure InitializeRequest.clientCapabilities defaults apply on decode when missing, +// and that the property is present on encode even when zero-value. +func TestInitializeRequest_ClientCapabilities_Defaults(t *testing.T) { + t.Parallel() + t.Run("unmarshal_missing_applies_defaults", func(t *testing.T) { + t.Parallel() + var req InitializeRequest + if err := json.Unmarshal([]byte(`{"protocolVersion":1}`), &req); err != nil { + t.Fatalf("unmarshal error: %v", err) + } + // Defaults per schema: terminal=false; fs.readTextFile=false; fs.writeTextFile=false + if req.ClientCapabilities.Terminal != false || + req.ClientCapabilities.Fs.ReadTextFile != false || + req.ClientCapabilities.Fs.WriteTextFile != false { + t.Fatalf("unexpected clientCapabilities defaults: %+v", req.ClientCapabilities) + } + }) + + t.Run("marshal_zero_includes_property", func(t *testing.T) { + t.Parallel() + req := InitializeRequest{ProtocolVersion: 1} + b, err := json.Marshal(req) + if err != nil { + t.Fatalf("marshal error: %v", err) + } + var m map[string]any + if err := json.Unmarshal(b, &m); err != nil { + t.Fatalf("roundtrip unmarshal error: %v", err) + } + if _, ok := m["clientCapabilities"]; !ok { + t.Fatalf("clientCapabilities should be present in JSON: %s", string(b)) + } + }) +} + +// Ensure InitializeResponse.agentCapabilities defaults apply on decode when missing. +func TestInitializeResponse_AgentCapabilities_Defaults(t *testing.T) { + t.Parallel() + t.Run("unmarshal_missing_applies_defaults", func(t *testing.T) { + t.Parallel() + var resp InitializeResponse + if err := json.Unmarshal([]byte(`{"protocolVersion":1}`), &resp); err != nil { + t.Fatalf("unmarshal error: %v", err) + } + // Defaults: loadSession=false; promptCapabilities audio=false, embeddedContext=false, image=false + if resp.AgentCapabilities.LoadSession != false || + resp.AgentCapabilities.PromptCapabilities.Audio != false || + resp.AgentCapabilities.PromptCapabilities.EmbeddedContext != false || + resp.AgentCapabilities.PromptCapabilities.Image != false { + t.Fatalf("unexpected agentCapabilities defaults: %+v", resp.AgentCapabilities) + } + }) +} diff --git a/go/example/agent/main.go b/go/example/agent/main.go index 44a5fcb..4138293 100644 --- a/go/example/agent/main.go +++ b/go/example/agent/main.go @@ -7,6 +7,8 @@ import ( "fmt" "io" "os" + "os/exec" + "os/signal" "time" acp "github.com/zed-industries/agent-client-protocol/go" @@ -254,11 +256,37 @@ func pause(ctx context.Context, d time.Duration) error { } func main() { - // Wire up stdio: write to stdout, read from stdin + // If args provided, treat them as client program + args to spawn and connect via stdio. + // Otherwise, default to stdio (allowing manual wiring or use by another process). + ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt, os.Kill) + defer cancel() + + var ( + out io.Writer = os.Stdout + in io.Reader = os.Stdin + cmd *exec.Cmd + ) + if len(os.Args) > 1 { + cmd = exec.CommandContext(ctx, os.Args[1], os.Args[2:]...) + cmd.Stderr = os.Stderr + stdin, _ := cmd.StdinPipe() + stdout, _ := cmd.StdoutPipe() + if err := cmd.Start(); err != nil { + fmt.Fprintf(os.Stderr, "failed to start client: %v\n", err) + os.Exit(1) + } + out = stdin + in = stdout + } + ag := newExampleAgent() - asc := acp.NewAgentSideConnection(ag, os.Stdout, os.Stdin) + asc := acp.NewAgentSideConnection(ag, out, in) ag.SetAgentConnection(asc) - // Block until the peer disconnects (stdin closes). + // Block until the peer disconnects. <-asc.Done() + + if cmd != nil && cmd.Process != nil { + _ = cmd.Process.Kill() + } } diff --git a/go/example/client/main.go b/go/example/client/main.go index 44fb982..be85a4d 100644 --- a/go/example/client/main.go +++ b/go/example/client/main.go @@ -8,6 +8,7 @@ import ( "os" "os/exec" "path/filepath" + "runtime" "strings" acp "github.com/zed-industries/agent-client-protocol/go" @@ -161,12 +162,25 @@ func main() { if len(os.Args) > 1 { cmd = exec.CommandContext(ctx, os.Args[1], os.Args[2:]...) } else { - // Assumes running from the go/ directory; if not, adjust path accordingly. - cmd = exec.CommandContext(ctx, "go", "run", "./example/agent") + // Default: run the Go example agent. Detect relative to this client's location. + _, filename, _, ok := runtime.Caller(0) + if !ok { + fmt.Fprintf(os.Stderr, "failed to determine current file location\n") + os.Exit(1) + } + + // Get directory of this client file and find sibling agent directory + clientDir := filepath.Dir(filename) + agentPath := filepath.Join(clientDir, "..", "agent") + + if _, err := os.Stat(agentPath); err != nil { + fmt.Fprintf(os.Stderr, "failed to find agent directory at %s: %v\n", agentPath, err) + os.Exit(1) + } + + cmd = exec.CommandContext(ctx, "go", "run", agentPath) } cmd.Stderr = os.Stderr - cmd.Stdout = nil - cmd.Stdin = nil // Set up pipes for stdio stdin, _ := cmd.StdinPipe() stdout, _ := cmd.StdoutPipe() diff --git a/go/go.mod b/go/go.mod deleted file mode 100644 index 0daf94a..0000000 --- a/go/go.mod +++ /dev/null @@ -1,3 +0,0 @@ -module github.com/zed-industries/agent-client-protocol/go - -go 1.21 diff --git a/go/json_parity_test.go b/go/json_parity_test.go index ed1f620..5918a2c 100644 --- a/go/json_parity_test.go +++ b/go/json_parity_test.go @@ -38,6 +38,7 @@ func mustReadGolden(t *testing.T, name string) []byte { func runGolden[T any](builds ...func() T) func(t *testing.T) { return func(t *testing.T) { t.Helper() + t.Parallel() // Use the current subtest name; expect pattern like "/". name := t.Name() base := name @@ -71,6 +72,7 @@ func runGolden[T any](builds ...func() T) func(t *testing.T) { } func TestJSONGolden_ContentBlocks(t *testing.T) { + t.Parallel() t.Run("content_text", runGolden( func() ContentBlock { return TextBlock("What's the weather like today?") }, )) @@ -110,6 +112,7 @@ func TestJSONGolden_ContentBlocks(t *testing.T) { } func TestJSONGolden_ToolCallContent(t *testing.T) { + t.Parallel() t.Run("tool_content_content_text", runGolden( func() ToolCallContent { return ToolContent(TextBlock("Analysis complete. Found 3 issues.")) }, )) @@ -128,6 +131,7 @@ func TestJSONGolden_ToolCallContent(t *testing.T) { } func TestJSONGolden_RequestPermissionOutcome(t *testing.T) { + t.Parallel() t.Run("permission_outcome_selected", runGolden( func() RequestPermissionOutcome { return RequestPermissionOutcome{Selected: &RequestPermissionOutcomeSelected{Outcome: "selected", OptionId: "allow-once"}} @@ -145,6 +149,7 @@ func TestJSONGolden_RequestPermissionOutcome(t *testing.T) { } func TestJSONGolden_SessionUpdates(t *testing.T) { + t.Parallel() t.Run("session_update_user_message_chunk", runGolden( func() SessionUpdate { return SessionUpdate{UserMessageChunk: &SessionUpdateUserMessageChunk{Content: TextBlock("What's the capital of France?")}} @@ -182,7 +187,6 @@ func TestJSONGolden_SessionUpdates(t *testing.T) { return StartToolCall("call_001", "Reading configuration file", WithStartKind(ToolKindRead), WithStartStatus(ToolCallStatusPending)) }, )) - // Removed: session_update_tool_call_minimal (deprecated New helper) t.Run("session_update_tool_call_read", runGolden( func() SessionUpdate { return StartReadToolCall("call_001", "Reading configuration file", "/home/user/project/src/config.json") @@ -206,7 +210,6 @@ func TestJSONGolden_SessionUpdates(t *testing.T) { return UpdateToolCall("call_001", WithUpdateStatus(ToolCallStatusInProgress), WithUpdateContent([]ToolCallContent{ToolContent(TextBlock("Found 3 configuration files..."))})) }, )) - // Removed: session_update_tool_call_update_minimal (deprecated New helper) t.Run("session_update_tool_call_update_more_fields", runGolden( func() SessionUpdate { return UpdateToolCall( @@ -224,11 +227,12 @@ func TestJSONGolden_SessionUpdates(t *testing.T) { } func TestJSONGolden_MethodPayloads(t *testing.T) { + t.Parallel() t.Run("initialize_request", runGolden(func() InitializeRequest { return InitializeRequest{ProtocolVersion: 1, ClientCapabilities: ClientCapabilities{Fs: FileSystemCapability{ReadTextFile: true, WriteTextFile: true}}} })) t.Run("initialize_response", runGolden(func() InitializeResponse { - return InitializeResponse{ProtocolVersion: 1, AgentCapabilities: AgentCapabilities{LoadSession: true, PromptCapabilities: PromptCapabilities{Image: true, Audio: true, EmbeddedContext: true}}, AuthMethods: []AuthMethod{}} + return InitializeResponse{ProtocolVersion: 1, AgentCapabilities: AgentCapabilities{LoadSession: true, PromptCapabilities: PromptCapabilities{Image: true, Audio: true, EmbeddedContext: true}}} })) t.Run("new_session_request", runGolden(func() NewSessionRequest { return NewSessionRequest{Cwd: "/home/user/project", McpServers: []McpServer{{Name: "filesystem", Command: "/path/to/mcp-server", Args: []string{"--stdio"}, Env: []EnvVariable{}}}} diff --git a/go/types_gen.go b/go/types_gen.go index d15416f..bdb9e7c 100644 --- a/go/types_gen.go +++ b/go/types_gen.go @@ -12,9 +12,46 @@ import ( // Capabilities supported by the agent. Advertised during initialization to inform the client about available features and content types. See protocol docs: [Agent Capabilities](https://agentclientprotocol.com/protocol/initialization#agent-capabilities) type AgentCapabilities struct { // Whether the agent supports 'session/load'. + // + // Defaults to false if unset. LoadSession bool `json:"loadSession,omitempty"` // Prompt capabilities supported by the agent. - PromptCapabilities PromptCapabilities `json:"promptCapabilities,omitempty"` + // + // Defaults to {"audio":false,"embeddedContext":false,"image":false} if unset. + PromptCapabilities PromptCapabilities `json:"promptCapabilities"` +} + +func (v AgentCapabilities) MarshalJSON() ([]byte, error) { + type Alias AgentCapabilities + var a Alias + a = Alias(v) + return json.Marshal(a) +} + +func (v *AgentCapabilities) UnmarshalJSON(b []byte) error { + var m map[string]json.RawMessage + if err := json.Unmarshal(b, &m); err != nil { + return err + } + type Alias AgentCapabilities + var a Alias + if err := json.Unmarshal(b, &a); err != nil { + return err + } + { + _rm, _ok := m["loadSession"] + if !_ok || (string(_rm) == "null") { + json.Unmarshal([]byte("false"), &a.LoadSession) + } + } + { + _rm, _ok := m["promptCapabilities"] + if !_ok || (string(_rm) == "null") { + json.Unmarshal([]byte("{\"audio\":false,\"embeddedContext\":false,\"image\":false}"), &a.PromptCapabilities) + } + } + *v = AgentCapabilities(a) + return nil } // All possible notifications that an agent can send to a client. This enum is used internally for routing RPC notifications. You typically won't need to use this directly - use the notification methods on the ['Client'] trait instead. Notifications do not expect a response. @@ -441,11 +478,48 @@ func (v *CancelNotification) Validate() error { // Capabilities supported by the client. Advertised during initialization to inform the agent about available features and methods. See protocol docs: [Client Capabilities](https://agentclientprotocol.com/protocol/initialization#client-capabilities) type ClientCapabilities struct { // File system capabilities supported by the client. Determines which file operations the agent can request. - Fs FileSystemCapability `json:"fs,omitempty"` + // + // Defaults to {"readTextFile":false,"writeTextFile":false} if unset. + Fs FileSystemCapability `json:"fs"` // **UNSTABLE** This capability is not part of the spec yet, and may be removed or changed at any point. + // + // Defaults to false if unset. Terminal bool `json:"terminal,omitempty"` } +func (v ClientCapabilities) MarshalJSON() ([]byte, error) { + type Alias ClientCapabilities + var a Alias + a = Alias(v) + return json.Marshal(a) +} + +func (v *ClientCapabilities) UnmarshalJSON(b []byte) error { + var m map[string]json.RawMessage + if err := json.Unmarshal(b, &m); err != nil { + return err + } + type Alias ClientCapabilities + var a Alias + if err := json.Unmarshal(b, &a); err != nil { + return err + } + { + _rm, _ok := m["fs"] + if !_ok || (string(_rm) == "null") { + json.Unmarshal([]byte("{\"readTextFile\":false,\"writeTextFile\":false}"), &a.Fs) + } + } + { + _rm, _ok := m["terminal"] + if !_ok || (string(_rm) == "null") { + json.Unmarshal([]byte("false"), &a.Terminal) + } + } + *v = ClientCapabilities(a) + return nil +} + // All possible notifications that a client can send to an agent. This enum is used internally for routing RPC notifications. You typically won't need to use this directly - use the notification methods on the ['Agent'] trait instead. Notifications do not expect a response. type ClientNotification struct { CancelNotification *CancelNotification `json:"-"` @@ -1223,11 +1297,48 @@ type EnvVariable struct { // File system capabilities that a client may support. See protocol docs: [FileSystem](https://agentclientprotocol.com/protocol/initialization#filesystem) type FileSystemCapability struct { // Whether the Client supports 'fs/read_text_file' requests. + // + // Defaults to false if unset. ReadTextFile bool `json:"readTextFile,omitempty"` // Whether the Client supports 'fs/write_text_file' requests. + // + // Defaults to false if unset. WriteTextFile bool `json:"writeTextFile,omitempty"` } +func (v FileSystemCapability) MarshalJSON() ([]byte, error) { + type Alias FileSystemCapability + var a Alias + a = Alias(v) + return json.Marshal(a) +} + +func (v *FileSystemCapability) UnmarshalJSON(b []byte) error { + var m map[string]json.RawMessage + if err := json.Unmarshal(b, &m); err != nil { + return err + } + type Alias FileSystemCapability + var a Alias + if err := json.Unmarshal(b, &a); err != nil { + return err + } + { + _rm, _ok := m["readTextFile"] + if !_ok || (string(_rm) == "null") { + json.Unmarshal([]byte("false"), &a.ReadTextFile) + } + } + { + _rm, _ok := m["writeTextFile"] + if !_ok || (string(_rm) == "null") { + json.Unmarshal([]byte("false"), &a.WriteTextFile) + } + } + *v = FileSystemCapability(a) + return nil +} + // An image provided to or from an LLM. type ImageContent struct { Annotations *Annotations `json:"annotations,omitempty"` @@ -1239,11 +1350,40 @@ type ImageContent struct { // Request parameters for the initialize method. Sent by the client to establish connection and negotiate capabilities. See protocol docs: [Initialization](https://agentclientprotocol.com/protocol/initialization) type InitializeRequest struct { // Capabilities supported by the client. - ClientCapabilities ClientCapabilities `json:"clientCapabilities,omitempty"` + // + // Defaults to {"fs":{"readTextFile":false,"writeTextFile":false},"terminal":false} if unset. + ClientCapabilities ClientCapabilities `json:"clientCapabilities"` // The latest protocol version supported by the client. ProtocolVersion ProtocolVersion `json:"protocolVersion"` } +func (v InitializeRequest) MarshalJSON() ([]byte, error) { + type Alias InitializeRequest + var a Alias + a = Alias(v) + return json.Marshal(a) +} + +func (v *InitializeRequest) UnmarshalJSON(b []byte) error { + var m map[string]json.RawMessage + if err := json.Unmarshal(b, &m); err != nil { + return err + } + type Alias InitializeRequest + var a Alias + if err := json.Unmarshal(b, &a); err != nil { + return err + } + { + _rm, _ok := m["clientCapabilities"] + if !_ok || (string(_rm) == "null") { + json.Unmarshal([]byte("{\"fs\":{\"readTextFile\":false,\"writeTextFile\":false},\"terminal\":false}"), &a.ClientCapabilities) + } + } + *v = InitializeRequest(a) + return nil +} + func (v *InitializeRequest) Validate() error { return nil } @@ -1251,13 +1391,53 @@ func (v *InitializeRequest) Validate() error { // Response from the initialize method. Contains the negotiated protocol version and agent capabilities. See protocol docs: [Initialization](https://agentclientprotocol.com/protocol/initialization) type InitializeResponse struct { // Capabilities supported by the agent. - AgentCapabilities AgentCapabilities `json:"agentCapabilities,omitempty"` + // + // Defaults to {"loadSession":false,"promptCapabilities":{"audio":false,"embeddedContext":false,"image":false}} if unset. + AgentCapabilities AgentCapabilities `json:"agentCapabilities"` // Authentication methods supported by the agent. + // + // Defaults to [] if unset. AuthMethods []AuthMethod `json:"authMethods"` // The protocol version the client specified if supported by the agent, or the latest protocol version supported by the agent. The client should disconnect, if it doesn't support this version. ProtocolVersion ProtocolVersion `json:"protocolVersion"` } +func (v InitializeResponse) MarshalJSON() ([]byte, error) { + type Alias InitializeResponse + var a Alias + a = Alias(v) + if a.AuthMethods == nil { + json.Unmarshal([]byte("[]"), &a.AuthMethods) + } + return json.Marshal(a) +} + +func (v *InitializeResponse) UnmarshalJSON(b []byte) error { + var m map[string]json.RawMessage + if err := json.Unmarshal(b, &m); err != nil { + return err + } + type Alias InitializeResponse + var a Alias + if err := json.Unmarshal(b, &a); err != nil { + return err + } + { + _rm, _ok := m["agentCapabilities"] + if !_ok || (string(_rm) == "null") { + json.Unmarshal([]byte("{\"loadSession\":false,\"promptCapabilities\":{\"audio\":false,\"embeddedContext\":false,\"image\":false}}"), &a.AgentCapabilities) + } + } + { + _rm, _ok := m["authMethods"] + if !_ok || (string(_rm) == "null") { + json.Unmarshal([]byte("[]"), &a.AuthMethods) + } + } + *v = InitializeResponse(a) + return nil +} + func (v *InitializeResponse) Validate() error { return nil } @@ -1396,13 +1576,58 @@ const ( // Prompt capabilities supported by the agent in 'session/prompt' requests. Baseline agent functionality requires support for ['ContentBlock::Text'] and ['ContentBlock::ResourceLink'] in prompt requests. Other variants must be explicitly opted in to. Capabilities for different types of content in prompt requests. Indicates which content types beyond the baseline (text and resource links) the agent can process. See protocol docs: [Prompt Capabilities](https://agentclientprotocol.com/protocol/initialization#prompt-capabilities) type PromptCapabilities struct { // Agent supports ['ContentBlock::Audio']. + // + // Defaults to false if unset. Audio bool `json:"audio,omitempty"` // Agent supports embedded context in 'session/prompt' requests. When enabled, the Client is allowed to include ['ContentBlock::Resource'] in prompt requests for pieces of context that are referenced in the message. + // + // Defaults to false if unset. EmbeddedContext bool `json:"embeddedContext,omitempty"` // Agent supports ['ContentBlock::Image']. + // + // Defaults to false if unset. Image bool `json:"image,omitempty"` } +func (v PromptCapabilities) MarshalJSON() ([]byte, error) { + type Alias PromptCapabilities + var a Alias + a = Alias(v) + return json.Marshal(a) +} + +func (v *PromptCapabilities) UnmarshalJSON(b []byte) error { + var m map[string]json.RawMessage + if err := json.Unmarshal(b, &m); err != nil { + return err + } + type Alias PromptCapabilities + var a Alias + if err := json.Unmarshal(b, &a); err != nil { + return err + } + { + _rm, _ok := m["audio"] + if !_ok || (string(_rm) == "null") { + json.Unmarshal([]byte("false"), &a.Audio) + } + } + { + _rm, _ok := m["embeddedContext"] + if !_ok || (string(_rm) == "null") { + json.Unmarshal([]byte("false"), &a.EmbeddedContext) + } + } + { + _rm, _ok := m["image"] + if !_ok || (string(_rm) == "null") { + json.Unmarshal([]byte("false"), &a.Image) + } + } + *v = PromptCapabilities(a) + return nil +} + // Request parameters for sending a user prompt to the agent. Contains the user's message and any additional context. See protocol docs: [User Message](https://agentclientprotocol.com/protocol/prompt-turn#1-user-message) type PromptRequest struct { // The blocks of content that compose the user's message. As a baseline, the Agent MUST support ['ContentBlock::Text'] and ['ContentBlock::ResourceLink'], while other variants are optionally enabled via ['PromptCapabilities']. The Client MUST adapt its interface according to ['PromptCapabilities']. The client MAY include referenced pieces of context as either ['ContentBlock::Resource'] or ['ContentBlock::ResourceLink']. When available, ['ContentBlock::Resource'] is preferred as it avoids extra round-trips and allows the message to include pieces of context from sources the agent may not have access to. diff --git a/package.json b/package.json index a3e5263..9c1de84 100644 --- a/package.json +++ b/package.json @@ -31,7 +31,9 @@ "prepublishOnly": "cp typescript/README.md README.md", "postpublish": "git checkout README.md", "clean": "rm -rf dist typescript/*.js typescript/*.d.ts typescript/*.js.map tsconfig.tsbuildinfo && cargo clean", - "test": "cargo check --all-targets && cargo test && vitest run", + "test": "cargo check --all-targets && cargo test && vitest run && npm run test:go", + "test:go": "cd go && go test ./...", + "test:go:race": "cd go && go test -race ./...", "test:ts": "vitest run", "test:ts:watch": "vitest", "generate:json-schema": "cd rust && cargo run --bin generate --features unstable", @@ -39,9 +41,9 @@ "generate:go": "cd go/cmd/generate && env -u GOPATH -u GOMODCACHE go run . && cd ../.. && env -u GOPATH -u GOMODCACHE go run mvdan.cc/gofumpt@latest -w .", "generate": "npm run generate:json-schema && npm run generate:ts-schema && npm run generate:go && npm run format", "build": "npm run generate && tsc", - "format": "prettier --write . && cargo fmt", - "format:check": "prettier --check . && cargo fmt -- --check", - "lint": "cargo clippy", + "format": "prettier --write . && cargo fmt && (cd go && env -u GOPATH -u GOMODCACHE go run mvdan.cc/gofumpt@latest -w .)", + "format:check": "prettier --check . && cargo fmt -- --check && (cd go && test -z \"$(env -u GOPATH -u GOMODCACHE go run mvdan.cc/gofumpt@latest -l .)\" || (echo 'gofumpt: found unformatted Go files' >&2; exit 1))", + "lint": "cargo clippy && (cd go && go vet ./...)", "lint:fix": "cargo clippy --fix", "check:go": "cd go && go build ./...", "check": "npm run lint && npm run format:check && npm run build && npm run test && npm run docs:ts:verify && npm run check:go", diff --git a/rust/markdown_generator.rs b/rust/markdown_generator.rs index 398733f..9355306 100644 --- a/rust/markdown_generator.rs +++ b/rust/markdown_generator.rs @@ -655,9 +655,11 @@ impl SideDocs { } fn extract_side_docs() -> SideDocs { - let output = Command::new("cargo") + // Try to run cargo rustdoc with the current toolchain first (works with rustup via rust-toolchain.toml + // and with Nix-provided nightly toolchains). If that fails, fall back to the rustup-style '+nightly' + // invocation for environments where a default stable toolchain is active. + let mut output = Command::new("cargo") .args([ - "+nightly", "rustdoc", "--lib", "--", @@ -670,10 +672,28 @@ fn extract_side_docs() -> SideDocs { .unwrap(); if !output.status.success() { - panic!( - "Failed to generate rustdoc JSON: {}", - String::from_utf8_lossy(&output.stderr) - ); + let fallback = Command::new("cargo") + .args([ + "+nightly", + "rustdoc", + "--lib", + "--", + "-Z", + "unstable-options", + "--output-format", + "json", + ]) + .output() + .unwrap(); + + if !fallback.status.success() { + panic!( + "Failed to generate rustdoc JSON. First attempt (no +nightly): {}\nFallback (+nightly) failed: {}", + String::from_utf8_lossy(&output.stderr), + String::from_utf8_lossy(&fallback.stderr) + ); + } + output = fallback; } // Parse the JSON output From 630c0dc4790c3277e1975e1ae29f794dd86a09b6 Mon Sep 17 00:00:00 2001 From: Thomas Kosiewski Date: Wed, 10 Sep 2025 14:50:02 +0200 Subject: [PATCH 19/22] refactor: improve connection handling with context-based cancellation and reduced code duplication Change-Id: I808b16216e09a94eca58cd8e95c99b3bb927596f Signed-off-by: Thomas Kosiewski --- go/connection.go | 231 ++++++++++++++++++++------------------- go/example/agent/main.go | 26 ++++- 2 files changed, 139 insertions(+), 118 deletions(-) diff --git a/go/connection.go b/go/connection.go index 1898b71..958fa68 100644 --- a/go/connection.go +++ b/go/connection.go @@ -2,8 +2,10 @@ package acp import ( "bufio" + "bytes" "context" "encoding/json" + "errors" "io" "sync" "sync/atomic" @@ -34,71 +36,72 @@ type Connection struct { nextID atomic.Uint64 pending map[string]*pendingResponse - done chan struct{} + ctx context.Context + cancel context.CancelCauseFunc } func NewConnection(handler MethodHandler, peerInput io.Writer, peerOutput io.Reader) *Connection { + ctx, cancel := context.WithCancelCause(context.Background()) c := &Connection{ w: peerInput, r: peerOutput, handler: handler, pending: make(map[string]*pendingResponse), - done: make(chan struct{}), + ctx: ctx, + cancel: cancel, } go c.receive() return c } func (c *Connection) receive() { + const ( + initialBufSize = 1024 * 1024 + maxBufSize = 10 * 1024 * 1024 + ) + scanner := bufio.NewScanner(c.r) - // increase buffer if needed - buf := make([]byte, 0, 1024*1024) - scanner.Buffer(buf, 10*1024*1024) + buf := make([]byte, 0, initialBufSize) + scanner.Buffer(buf, maxBufSize) + for scanner.Scan() { line := scanner.Bytes() - if len(bytesTrimSpace(line)) == 0 { + if len(bytes.TrimSpace(line)) == 0 { continue } + var msg anyMessage if err := json.Unmarshal(line, &msg); err != nil { - // ignore parse errors on inbound continue } - if msg.ID != nil && msg.Method == "" { - // response - idStr := string(*msg.ID) - c.mu.Lock() - pr := c.pending[idStr] - if pr != nil { - delete(c.pending, idStr) - } - c.mu.Unlock() - if pr != nil { - pr.ch <- msg - } - continue - } - if msg.Method != "" { - // request or notification + + switch { + case msg.ID != nil && msg.Method == "": + c.handleResponse(&msg) + case msg.Method != "": go c.handleInbound(&msg) } } - // Signal completion on EOF or read error + + c.cancel(errors.New("peer connection closed")) +} + +func (c *Connection) handleResponse(msg *anyMessage) { + idStr := string(*msg.ID) + c.mu.Lock() - if c.done != nil { - close(c.done) - c.done = nil + pr := c.pending[idStr] + if pr != nil { + delete(c.pending, idStr) } c.mu.Unlock() + + if pr != nil { + pr.ch <- *msg + } } func (c *Connection) handleInbound(req *anyMessage) { - // Context that cancels when the connection is closed - ctx, cancel := context.WithCancel(context.Background()) - go func() { - <-c.Done() - cancel() - }() res := anyMessage{JSONRPC: "2.0"} // copy ID if present if req.ID != nil { @@ -112,7 +115,7 @@ func (c *Connection) handleInbound(req *anyMessage) { return } - result, err := c.handler(ctx, req.Method, req.Params) + result, err := c.handler(c.ctx, req.Method, req.Params) if req.ID == nil { // notification: nothing to send return @@ -147,93 +150,101 @@ func (c *Connection) sendMessage(msg anyMessage) error { // SendRequest sends a JSON-RPC request and returns a typed result. // For methods that do not return a result, use SendRequestNoResult instead. func SendRequest[T any](c *Connection, ctx context.Context, method string, params any) (T, error) { - var zero T - // allocate id - id := c.nextID.Add(1) - idRaw, _ := json.Marshal(id) - msg := anyMessage{ - JSONRPC: "2.0", - ID: (*json.RawMessage)(&idRaw), - Method: method, - } - if params != nil { - b, err := json.Marshal(params) - if err != nil { - return zero, NewInvalidParams(map[string]any{"error": err.Error()}) - } - msg.Params = b + var result T + + msg, idKey, err := c.prepareRequest(method, params) + if err != nil { + return result, err } + pr := &pendingResponse{ch: make(chan anyMessage, 1)} - idKey := string(idRaw) c.mu.Lock() c.pending[idKey] = pr c.mu.Unlock() + if err := c.sendMessage(msg); err != nil { - return zero, NewInternalError(map[string]any{"error": err.Error()}) + c.cleanupPending(idKey) + return result, NewInternalError(map[string]any{"error": err.Error()}) } - // wait for response or peer disconnect - var resp anyMessage - d := c.Done() - select { - case resp = <-pr.ch: - case <-ctx.Done(): - // best-effort cleanup - c.mu.Lock() - delete(c.pending, idKey) - c.mu.Unlock() - return zero, NewInternalError(map[string]any{"error": ctx.Err().Error()}) - case <-d: - return zero, NewInternalError(map[string]any{"error": "peer disconnected before response"}) + + resp, err := c.waitForResponse(ctx, pr, idKey) + if err != nil { + return result, err } + if resp.Error != nil { - return zero, resp.Error + return result, resp.Error } - var out T + if len(resp.Result) > 0 { - if err := json.Unmarshal(resp.Result, &out); err != nil { - return zero, NewInternalError(map[string]any{"error": err.Error()}) + if err := json.Unmarshal(resp.Result, &result); err != nil { + return result, NewInternalError(map[string]any{"error": err.Error()}) } } - return out, nil + return result, nil } -// SendRequestNoResult sends a JSON-RPC request that returns no result payload. -func (c *Connection) SendRequestNoResult(ctx context.Context, method string, params any) error { - // allocate id +func (c *Connection) prepareRequest(method string, params any) (anyMessage, string, error) { id := c.nextID.Add(1) idRaw, _ := json.Marshal(id) + msg := anyMessage{ JSONRPC: "2.0", ID: (*json.RawMessage)(&idRaw), Method: method, } + if params != nil { b, err := json.Marshal(params) if err != nil { - return NewInvalidParams(map[string]any{"error": err.Error()}) + return msg, "", NewInvalidParams(map[string]any{"error": err.Error()}) } msg.Params = b } + + return msg, string(idRaw), nil +} + +func (c *Connection) waitForResponse(ctx context.Context, pr *pendingResponse, idKey string) (anyMessage, error) { + select { + case resp := <-pr.ch: + return resp, nil + case <-ctx.Done(): + c.cleanupPending(idKey) + return anyMessage{}, NewInternalError(map[string]any{"error": context.Cause(ctx).Error()}) + case <-c.Done(): + return anyMessage{}, NewInternalError(map[string]any{"error": "peer disconnected before response"}) + } +} + +func (c *Connection) cleanupPending(idKey string) { + c.mu.Lock() + delete(c.pending, idKey) + c.mu.Unlock() +} + +// SendRequestNoResult sends a JSON-RPC request that returns no result payload. +func (c *Connection) SendRequestNoResult(ctx context.Context, method string, params any) error { + msg, idKey, err := c.prepareRequest(method, params) + if err != nil { + return err + } + pr := &pendingResponse{ch: make(chan anyMessage, 1)} - idKey := string(idRaw) c.mu.Lock() c.pending[idKey] = pr c.mu.Unlock() + if err := c.sendMessage(msg); err != nil { + c.cleanupPending(idKey) return NewInternalError(map[string]any{"error": err.Error()}) } - var resp anyMessage - d := c.Done() - select { - case resp = <-pr.ch: - case <-ctx.Done(): - c.mu.Lock() - delete(c.pending, idKey) - c.mu.Unlock() - return NewInternalError(map[string]any{"error": ctx.Err().Error()}) - case <-d: - return NewInternalError(map[string]any{"error": "peer disconnected before response"}) + + resp, err := c.waitForResponse(ctx, pr, idKey) + if err != nil { + return err } + if resp.Error != nil { return resp.Error } @@ -246,43 +257,37 @@ func (c *Connection) SendNotification(ctx context.Context, method string, params return NewInternalError(map[string]any{"error": ctx.Err().Error()}) default: } - msg := anyMessage{JSONRPC: "2.0", Method: method} + + msg, err := c.prepareNotification(method, params) + if err != nil { + return err + } + + if err := c.sendMessage(msg); err != nil { + return NewInternalError(map[string]any{"error": err.Error()}) + } + return nil +} + +func (c *Connection) prepareNotification(method string, params any) (anyMessage, error) { + msg := anyMessage{ + JSONRPC: "2.0", + Method: method, + } + if params != nil { b, err := json.Marshal(params) if err != nil { - return NewInvalidParams(map[string]any{"error": err.Error()}) + return msg, NewInvalidParams(map[string]any{"error": err.Error()}) } msg.Params = b } - if err := c.sendMessage(msg); err != nil { - return NewInternalError(map[string]any{"error": err.Error()}) - } - return nil + + return msg, nil } // Done returns a channel that is closed when the underlying reader loop exits // (typically when the peer disconnects or the input stream is closed). func (c *Connection) Done() <-chan struct{} { - c.mu.Lock() - d := c.done - c.mu.Unlock() - return d -} - -// Helper: lightweight TrimSpace for []byte without importing bytes only for this. -func bytesTrimSpace(b []byte) []byte { - i := 0 - for ; i < len(b); i++ { - if b[i] != ' ' && b[i] != '\t' && b[i] != '\r' && b[i] != '\n' { - break - } - } - j := len(b) - for j > i { - if b[j-1] != ' ' && b[j-1] != '\t' && b[j-1] != '\r' && b[j-1] != '\n' { - break - } - j-- - } - return b[i:j] + return c.ctx.Done() } diff --git a/go/example/agent/main.go b/go/example/agent/main.go index 4138293..dbb2939 100644 --- a/go/example/agent/main.go +++ b/go/example/agent/main.go @@ -9,6 +9,7 @@ import ( "os" "os/exec" "os/signal" + "sync" "time" acp "github.com/zed-industries/agent-client-protocol/go" @@ -21,6 +22,7 @@ type agentSession struct { type exampleAgent struct { conn *acp.AgentSideConnection sessions map[string]*agentSession + mu sync.Mutex } var ( @@ -46,7 +48,9 @@ func (a *exampleAgent) Initialize(ctx context.Context, params acp.InitializeRequ func (a *exampleAgent) NewSession(ctx context.Context, params acp.NewSessionRequest) (acp.NewSessionResponse, error) { sid := randomID() + a.mu.Lock() a.sessions[sid] = &agentSession{} + a.mu.Unlock() return acp.NewSessionResponse{SessionId: acp.SessionId(sid)}, nil } @@ -55,27 +59,37 @@ func (a *exampleAgent) Authenticate(ctx context.Context, _ acp.AuthenticateReque func (a *exampleAgent) LoadSession(ctx context.Context, _ acp.LoadSessionRequest) error { return nil } func (a *exampleAgent) Cancel(ctx context.Context, params acp.CancelNotification) error { - if s, ok := a.sessions[string(params.SessionId)]; ok { - if s.cancel != nil { - s.cancel() - } + a.mu.Lock() + s, ok := a.sessions[string(params.SessionId)] + a.mu.Unlock() + if ok && s != nil && s.cancel != nil { + s.cancel() } return nil } func (a *exampleAgent) Prompt(_ context.Context, params acp.PromptRequest) (acp.PromptResponse, error) { sid := string(params.SessionId) + a.mu.Lock() s, ok := a.sessions[sid] + a.mu.Unlock() if !ok { return acp.PromptResponse{}, fmt.Errorf("session %s not found", sid) } // cancel any previous turn + a.mu.Lock() if s.cancel != nil { - s.cancel() + prev := s.cancel + a.mu.Unlock() + prev() + } else { + a.mu.Unlock() } ctx, cancel := context.WithCancel(context.Background()) + a.mu.Lock() s.cancel = cancel + a.mu.Unlock() // simulate a full turn with streaming updates and a permission request if err := a.simulateTurn(ctx, sid); err != nil { @@ -84,7 +98,9 @@ func (a *exampleAgent) Prompt(_ context.Context, params acp.PromptRequest) (acp. } return acp.PromptResponse{}, err } + a.mu.Lock() s.cancel = nil + a.mu.Unlock() return acp.PromptResponse{StopReason: acp.StopReasonEndTurn}, nil } From 422dc568ea855e43fb90b7576ef9774e72c8e4fd Mon Sep 17 00:00:00 2001 From: Thomas Kosiewski Date: Wed, 10 Sep 2025 15:12:00 +0200 Subject: [PATCH 20/22] feat: add session modes and MCP capabilities support Change-Id: I5db6760df4ae4f89ebacc3f1cc83f15d95afd0b9 Signed-off-by: Thomas Kosiewski --- .gitignore | 2 + flake.nix | 33 ++- go/agent_gen.go | 5 +- go/client_gen.go | 5 +- go/constants_gen.go | 13 +- go/types_gen.go | 515 ++++++++++++++++++++++++++++++++++--- rust/markdown_generator.rs | 32 +-- 7 files changed, 529 insertions(+), 76 deletions(-) diff --git a/.gitignore b/.gitignore index 42c9116..806bc05 100644 --- a/.gitignore +++ b/.gitignore @@ -20,3 +20,5 @@ typescript/docs/ .envrc .direnv +.rustup/ +.cargo/ diff --git a/flake.nix b/flake.nix index 094b83c..2585f91 100644 --- a/flake.nix +++ b/flake.nix @@ -24,17 +24,16 @@ formatter = pkgs.nixfmt-rfc-style; - # Rust toolchain derived from rust-toolchain.toml - # Uses oxalica/rust-overlay to match the pinned channel/components. - rustToolchain = pkgs.rust-bin.fromRustupToolchainFile ./rust-toolchain.toml; + # Use rustup to manage toolchains so `cargo +nightly` works in dev shell + rustup = pkgs.rustup; in { inherit formatter; devShells.default = pkgs.mkShell { packages = with pkgs; [ - # Rust toolchain pinned via rust-toolchain.toml - rustToolchain + # Rustup manages stable/nightly toolchains according to rust-toolchain.toml + rustup pkg-config openssl @@ -49,6 +48,30 @@ ]; RUST_BACKTRACE = "1"; + + # Ensure rustup shims are used and install required toolchains on first entry + shellHook = '' + export RUSTUP_HOME="$PWD/.rustup" + export CARGO_HOME="$PWD/.cargo" + export PATH="$CARGO_HOME/bin:$PATH" + + if ! command -v rustup >/dev/null 2>&1; then + echo "rustup not found in PATH" 1>&2 + else + # Install toolchains if missing; respect pinned channel from rust-toolchain.toml + if ! rustup toolchain list | grep -q nightly; then + rustup toolchain install nightly --profile minimal >/dev/null 2>&1 || true + fi + # Ensure stable toolchain from rust-toolchain.toml exists (rustup will auto-select it) + # Attempt to install channel specified in rust-toolchain.toml (fallback to stable) + TOOLCHAIN_CHANNEL=$(sed -n 's/^channel\s*=\s*"\(.*\)"/\1/p' rust-toolchain.toml || true) + if [ -n "$TOOLCHAIN_CHANNEL" ]; then + if ! rustup toolchain list | grep -q "$TOOLCHAIN_CHANNEL"; then + rustup toolchain install "$TOOLCHAIN_CHANNEL" --profile minimal --component rustfmt clippy >/dev/null 2>&1 || true + fi + fi + fi + ''; }; } ); diff --git a/go/agent_gen.go b/go/agent_gen.go index 2a80522..883fb21 100644 --- a/go/agent_gen.go +++ b/go/agent_gen.go @@ -64,10 +64,11 @@ func (a *AgentSideConnection) handle(ctx context.Context, method string, params if !ok { return nil, NewMethodNotFound(method) } - if err := loader.LoadSession(ctx, p); err != nil { + resp, err := loader.LoadSession(ctx, p) + if err != nil { return nil, toReqErr(err) } - return nil, nil + return resp, nil case AgentMethodSessionNew: var p NewSessionRequest if err := json.Unmarshal(params, &p); err != nil { diff --git a/go/client_gen.go b/go/client_gen.go index ead5e0a..ced57c4 100644 --- a/go/client_gen.go +++ b/go/client_gen.go @@ -140,8 +140,9 @@ func (c *ClientSideConnection) Initialize(ctx context.Context, params Initialize func (c *ClientSideConnection) Cancel(ctx context.Context, params CancelNotification) error { return c.conn.SendNotification(ctx, AgentMethodSessionCancel, params) } -func (c *ClientSideConnection) LoadSession(ctx context.Context, params LoadSessionRequest) error { - return c.conn.SendRequestNoResult(ctx, AgentMethodSessionLoad, params) +func (c *ClientSideConnection) LoadSession(ctx context.Context, params LoadSessionRequest) (LoadSessionResponse, error) { + resp, err := SendRequest[LoadSessionResponse](c.conn, ctx, AgentMethodSessionLoad, params) + return resp, err } func (c *ClientSideConnection) NewSession(ctx context.Context, params NewSessionRequest) (NewSessionResponse, error) { resp, err := SendRequest[NewSessionResponse](c.conn, ctx, AgentMethodSessionNew, params) diff --git a/go/constants_gen.go b/go/constants_gen.go index 1e08f03..3c2b914 100644 --- a/go/constants_gen.go +++ b/go/constants_gen.go @@ -7,12 +7,13 @@ const ProtocolVersionNumber = 1 // Agent method names const ( - AgentMethodAuthenticate = "authenticate" - AgentMethodInitialize = "initialize" - AgentMethodSessionCancel = "session/cancel" - AgentMethodSessionLoad = "session/load" - AgentMethodSessionNew = "session/new" - AgentMethodSessionPrompt = "session/prompt" + AgentMethodAuthenticate = "authenticate" + AgentMethodInitialize = "initialize" + AgentMethodSessionCancel = "session/cancel" + AgentMethodSessionLoad = "session/load" + AgentMethodSessionNew = "session/new" + AgentMethodSessionPrompt = "session/prompt" + AgentMethodSessionSetMode = "session/set_mode" ) // Client method names diff --git a/go/types_gen.go b/go/types_gen.go index bdb9e7c..9abbbdd 100644 --- a/go/types_gen.go +++ b/go/types_gen.go @@ -15,6 +15,10 @@ type AgentCapabilities struct { // // Defaults to false if unset. LoadSession bool `json:"loadSession,omitempty"` + // MCP capabilities supported by the agent. + // + // Defaults to {"http":false,"sse":false} if unset. + McpCapabilities McpCapabilities `json:"mcpCapabilities"` // Prompt capabilities supported by the agent. // // Defaults to {"audio":false,"embeddedContext":false,"image":false} if unset. @@ -44,6 +48,12 @@ func (v *AgentCapabilities) UnmarshalJSON(b []byte) error { json.Unmarshal([]byte("false"), &a.LoadSession) } } + { + _rm, _ok := m["mcpCapabilities"] + if !_ok || (string(_rm) == "null") { + json.Unmarshal([]byte("{\"http\":false,\"sse\":false}"), &a.McpCapabilities) + } + } { _rm, _ok := m["promptCapabilities"] if !_ok || (string(_rm) == "null") { @@ -258,14 +268,13 @@ func (u AgentRequest) MarshalJSON() ([]byte, error) { // All possible responses that an agent can send to a client. This enum is used internally for routing RPC responses. You typically won't need to use this directly - the responses are handled automatically by the connection. These are responses to the corresponding ClientRequest variants. type AuthenticateResponse struct{} -type LoadSessionResponse struct{} - type AgentResponse struct { - InitializeResponse *InitializeResponse `json:"-"` - AuthenticateResponse *AuthenticateResponse `json:"-"` - NewSessionResponse *NewSessionResponse `json:"-"` - LoadSessionResponse *LoadSessionResponse `json:"-"` - PromptResponse *PromptResponse `json:"-"` + InitializeResponse *InitializeResponse `json:"-"` + AuthenticateResponse *AuthenticateResponse `json:"-"` + NewSessionResponse *NewSessionResponse `json:"-"` + LoadSessionResponse *LoadSessionResponse `json:"-"` + SetSessionModeResponse *SetSessionModeResponse `json:"-"` + PromptResponse *PromptResponse `json:"-"` } func (u *AgentResponse) UnmarshalJSON(b []byte) error { @@ -306,6 +315,13 @@ func (u *AgentResponse) UnmarshalJSON(b []byte) error { return nil } } + { + var v SetSessionModeResponse + if json.Unmarshal(b, &v) == nil { + u.SetSessionModeResponse = &v + return nil + } + } { var v PromptResponse if json.Unmarshal(b, &v) == nil { @@ -342,7 +358,26 @@ func (u AgentResponse) MarshalJSON() ([]byte, error) { return json.Marshal(m) } if u.LoadSessionResponse != nil { - return json.Marshal(nil) + var m map[string]any + _b, _e := json.Marshal(*u.LoadSessionResponse) + if _e != nil { + return []byte{}, _e + } + if json.Unmarshal(_b, &m) != nil { + return []byte{}, errors.New("invalid variant payload") + } + return json.Marshal(m) + } + if u.SetSessionModeResponse != nil { + var m map[string]any + _b, _e := json.Marshal(*u.SetSessionModeResponse) + if _e != nil { + return []byte{}, _e + } + if json.Unmarshal(_b, &m) != nil { + return []byte{}, errors.New("invalid variant payload") + } + return json.Marshal(m) } if u.PromptResponse != nil { var m map[string]any @@ -556,11 +591,12 @@ func (u ClientNotification) MarshalJSON() ([]byte, error) { // All possible requests that a client can send to an agent. This enum is used internally for routing RPC requests. You typically won't need to use this directly - instead, use the methods on the ['Agent'] trait. This enum encompasses all method calls from client to agent. type ClientRequest struct { - InitializeRequest *InitializeRequest `json:"-"` - AuthenticateRequest *AuthenticateRequest `json:"-"` - NewSessionRequest *NewSessionRequest `json:"-"` - LoadSessionRequest *LoadSessionRequest `json:"-"` - PromptRequest *PromptRequest `json:"-"` + InitializeRequest *InitializeRequest `json:"-"` + AuthenticateRequest *AuthenticateRequest `json:"-"` + NewSessionRequest *NewSessionRequest `json:"-"` + LoadSessionRequest *LoadSessionRequest `json:"-"` + SetSessionModeRequest *SetSessionModeRequest `json:"-"` + PromptRequest *PromptRequest `json:"-"` } func (u *ClientRequest) UnmarshalJSON(b []byte) error { @@ -596,6 +632,13 @@ func (u *ClientRequest) UnmarshalJSON(b []byte) error { return nil } } + { + var v SetSessionModeRequest + if json.Unmarshal(b, &v) == nil { + u.SetSessionModeRequest = &v + return nil + } + } { var v PromptRequest if json.Unmarshal(b, &v) == nil { @@ -650,6 +693,17 @@ func (u ClientRequest) MarshalJSON() ([]byte, error) { } return json.Marshal(m) } + if u.SetSessionModeRequest != nil { + var m map[string]any + _b, _e := json.Marshal(*u.SetSessionModeRequest) + if _e != nil { + return []byte{}, _e + } + if json.Unmarshal(_b, &m) != nil { + return []byte{}, errors.New("invalid variant payload") + } + return json.Marshal(m) + } if u.PromptRequest != nil { var m map[string]any _b, _e := json.Marshal(*u.PromptRequest) @@ -1339,6 +1393,14 @@ func (v *FileSystemCapability) UnmarshalJSON(b []byte) error { return nil } +// An HTTP header to set when making requests to the MCP server. +type HttpHeader struct { + // The name of the HTTP header. + Name string `json:"name"` + // The value to set for the HTTP header. + Value string `json:"value"` +} + // An image provided to or from an LLM. type ImageContent struct { Annotations *Annotations `json:"annotations,omitempty"` @@ -1392,7 +1454,7 @@ func (v *InitializeRequest) Validate() error { type InitializeResponse struct { // Capabilities supported by the agent. // - // Defaults to {"loadSession":false,"promptCapabilities":{"audio":false,"embeddedContext":false,"image":false}} if unset. + // Defaults to {"loadSession":false,"mcpCapabilities":{"http":false,"sse":false},"promptCapabilities":{"audio":false,"embeddedContext":false,"image":false}} if unset. AgentCapabilities AgentCapabilities `json:"agentCapabilities"` // Authentication methods supported by the agent. // @@ -1425,7 +1487,7 @@ func (v *InitializeResponse) UnmarshalJSON(b []byte) error { { _rm, _ok := m["agentCapabilities"] if !_ok || (string(_rm) == "null") { - json.Unmarshal([]byte("{\"loadSession\":false,\"promptCapabilities\":{\"audio\":false,\"embeddedContext\":false,\"image\":false}}"), &a.AgentCapabilities) + json.Unmarshal([]byte("{\"loadSession\":false,\"mcpCapabilities\":{\"http\":false,\"sse\":false},\"promptCapabilities\":{\"audio\":false,\"embeddedContext\":false,\"image\":false}}"), &a.AgentCapabilities) } } { @@ -1454,7 +1516,7 @@ func (v *KillTerminalRequest) Validate() error { return nil } -// Request parameters for loading an existing session. Only available if the agent supports the 'loadSession' capability. See protocol docs: [Loading Sessions](https://agentclientprotocol.com/protocol/session-setup#loading-sessions) +// Request parameters for loading an existing session. Only available if the Agent supports the 'loadSession' capability. See protocol docs: [Loading Sessions](https://agentclientprotocol.com/protocol/session-setup#loading-sessions) type LoadSessionRequest struct { // The working directory for this session. Cwd string `json:"cwd"` @@ -1474,8 +1536,86 @@ func (v *LoadSessionRequest) Validate() error { return nil } +// Response from loading an existing session. +type LoadSessionResponse struct { + // **UNSTABLE** This field is not part of the spec, and may be removed or changed at any point. + Modes *SessionModeState `json:"modes,omitempty"` +} + +func (v *LoadSessionResponse) Validate() error { + return nil +} + +// MCP capabilities supported by the agent +type McpCapabilities struct { + // Agent supports ['McpServer::Http']. + // + // Defaults to false if unset. + Http bool `json:"http,omitempty"` + // Agent supports ['McpServer::Sse']. + // + // Defaults to false if unset. + Sse bool `json:"sse,omitempty"` +} + +func (v McpCapabilities) MarshalJSON() ([]byte, error) { + type Alias McpCapabilities + var a Alias + a = Alias(v) + return json.Marshal(a) +} + +func (v *McpCapabilities) UnmarshalJSON(b []byte) error { + var m map[string]json.RawMessage + if err := json.Unmarshal(b, &m); err != nil { + return err + } + type Alias McpCapabilities + var a Alias + if err := json.Unmarshal(b, &a); err != nil { + return err + } + { + _rm, _ok := m["http"] + if !_ok || (string(_rm) == "null") { + json.Unmarshal([]byte("false"), &a.Http) + } + } + { + _rm, _ok := m["sse"] + if !_ok || (string(_rm) == "null") { + json.Unmarshal([]byte("false"), &a.Sse) + } + } + *v = McpCapabilities(a) + return nil +} + // Configuration for connecting to an MCP (Model Context Protocol) server. MCP servers provide tools and context that the agent can use when processing prompts. See protocol docs: [MCP Servers](https://agentclientprotocol.com/protocol/session-setup#mcp-servers) -type McpServer struct { +// HTTP transport configuration Only available when the Agent capabilities indicate 'mcp_capabilities.http' is 'true'. +type McpServerHttp struct { + // HTTP headers to set when making requests to the MCP server. + Headers []HttpHeader `json:"headers"` + // Human-readable name identifying this MCP server. + Name string `json:"name"` + Type string `json:"type"` + // URL to the MCP server. + Url string `json:"url"` +} + +// SSE transport configuration Only available when the Agent capabilities indicate 'mcp_capabilities.sse' is 'true'. +type McpServerSse struct { + // HTTP headers to set when making requests to the MCP server. + Headers []HttpHeader `json:"headers"` + // Human-readable name identifying this MCP server. + Name string `json:"name"` + Type string `json:"type"` + // URL to the MCP server. + Url string `json:"url"` +} + +// Stdio transport configuration All Agents MUST support this transport. +type stdio struct { // Command-line arguments to pass to the MCP server. Args []string `json:"args"` // Path to the MCP server executable. @@ -1486,6 +1626,170 @@ type McpServer struct { Name string `json:"name"` } +type McpServer struct { + Http *McpServerHttp `json:"-"` + Sse *McpServerSse `json:"-"` + stdio *stdio `json:"-"` +} + +func (u *McpServer) UnmarshalJSON(b []byte) error { + var m map[string]json.RawMessage + if err := json.Unmarshal(b, &m); err != nil { + return err + } + { + var disc string + if v, ok := m["type"]; ok { + json.Unmarshal(v, &disc) + } + switch disc { + case "http": + var v McpServerHttp + if json.Unmarshal(b, &v) != nil { + return errors.New("invalid variant payload") + } + u.Http = &v + return nil + case "sse": + var v McpServerSse + if json.Unmarshal(b, &v) != nil { + return errors.New("invalid variant payload") + } + u.Sse = &v + return nil + } + } + { + var v McpServerHttp + var match bool = true + if _, ok := m["type"]; !ok { + match = false + } + if _, ok := m["name"]; !ok { + match = false + } + if _, ok := m["url"]; !ok { + match = false + } + if _, ok := m["headers"]; !ok { + match = false + } + if match { + if json.Unmarshal(b, &v) != nil { + return errors.New("invalid variant payload") + } + u.Http = &v + return nil + } + } + { + var v McpServerSse + var match bool = true + if _, ok := m["type"]; !ok { + match = false + } + if _, ok := m["name"]; !ok { + match = false + } + if _, ok := m["url"]; !ok { + match = false + } + if _, ok := m["headers"]; !ok { + match = false + } + if match { + if json.Unmarshal(b, &v) != nil { + return errors.New("invalid variant payload") + } + u.Sse = &v + return nil + } + } + { + var v stdio + var match bool = true + if _, ok := m["name"]; !ok { + match = false + } + if _, ok := m["command"]; !ok { + match = false + } + if _, ok := m["args"]; !ok { + match = false + } + if _, ok := m["env"]; !ok { + match = false + } + if match { + if json.Unmarshal(b, &v) != nil { + return errors.New("invalid variant payload") + } + u.stdio = &v + return nil + } + } + { + var v McpServerHttp + if json.Unmarshal(b, &v) == nil { + u.Http = &v + return nil + } + } + { + var v McpServerSse + if json.Unmarshal(b, &v) == nil { + u.Sse = &v + return nil + } + } + { + var v stdio + if json.Unmarshal(b, &v) == nil { + u.stdio = &v + return nil + } + } + return nil +} +func (u McpServer) MarshalJSON() ([]byte, error) { + if u.Http != nil { + var m map[string]any + _b, _e := json.Marshal(*u.Http) + if _e != nil { + return []byte{}, _e + } + if json.Unmarshal(_b, &m) != nil { + return []byte{}, errors.New("invalid variant payload") + } + m["type"] = "http" + return json.Marshal(m) + } + if u.Sse != nil { + var m map[string]any + _b, _e := json.Marshal(*u.Sse) + if _e != nil { + return []byte{}, _e + } + if json.Unmarshal(_b, &m) != nil { + return []byte{}, errors.New("invalid variant payload") + } + m["type"] = "sse" + return json.Marshal(m) + } + if u.stdio != nil { + var m map[string]any + _b, _e := json.Marshal(*u.stdio) + if _e != nil { + return []byte{}, _e + } + if json.Unmarshal(_b, &m) != nil { + return []byte{}, errors.New("invalid variant payload") + } + return json.Marshal(m) + } + return []byte{}, nil +} + // Request parameters for creating a new session. See protocol docs: [Creating a Session](https://agentclientprotocol.com/protocol/session-setup#creating-a-session) type NewSessionRequest struct { // The working directory for this session. Must be an absolute path. @@ -1506,8 +1810,8 @@ func (v *NewSessionRequest) Validate() error { // Response from creating a new session. See protocol docs: [Creating a Session](https://agentclientprotocol.com/protocol/session-setup#creating-a-session) type NewSessionResponse struct { - // **UNSTABLE** Commands that may be executed via 'session/prompt' requests - AvailableCommands []AvailableCommand `json:"availableCommands,omitempty"` + // **UNSTABLE** This field is not part of the spec, and may be removed or changed at any point. + Modes *SessionModeState `json:"modes,omitempty"` // Unique identifier for the created session. Used in all subsequent requests for this conversation. SessionId SessionId `json:"sessionId"` } @@ -1882,6 +2186,22 @@ const ( // A unique identifier for a conversation session between a client and agent. Sessions maintain their own context, conversation history, and state, allowing multiple independent interactions with the same agent. # Example ”' use agent_client_protocol::SessionId; use std::sync::Arc; let session_id = SessionId(Arc::from("sess_abc123def456")); ”' See protocol docs: [Session ID](https://agentclientprotocol.com/protocol/session-setup#session-id) type SessionId string +// **UNSTABLE** This type is not part of the spec, and may be removed or changed at any point. +type SessionMode struct { + Description *string `json:"description,omitempty"` + Id SessionModeId `json:"id"` + Name string `json:"name"` +} + +// **UNSTABLE** This type is not part of the spec, and may be removed or changed at any point. +type SessionModeId string + +// **UNSTABLE** This type is not part of the spec, and may be removed or changed at any point. +type SessionModeState struct { + AvailableModes []SessionMode `json:"availableModes"` + CurrentModeId SessionModeId `json:"currentModeId"` +} + // Notification containing a session update from the agent. Used to stream real-time progress and results during prompt processing. See protocol docs: [Agent Reports Output](https://agentclientprotocol.com/protocol/prompt-turn#3-agent-reports-output) type SessionNotification struct { // The ID of the session this update pertains to. @@ -1962,13 +2282,27 @@ type SessionUpdatePlan struct { SessionUpdate string `json:"sessionUpdate"` } +// Available commands are ready or have changed +type SessionUpdateAvailableCommandsUpdate struct { + AvailableCommands []AvailableCommand `json:"availableCommands"` + SessionUpdate string `json:"sessionUpdate"` +} + +// The current mode of the session has changed +type SessionUpdateCurrentModeUpdate struct { + CurrentModeId SessionModeId `json:"currentModeId"` + SessionUpdate string `json:"sessionUpdate"` +} + type SessionUpdate struct { - UserMessageChunk *SessionUpdateUserMessageChunk `json:"-"` - AgentMessageChunk *SessionUpdateAgentMessageChunk `json:"-"` - AgentThoughtChunk *SessionUpdateAgentThoughtChunk `json:"-"` - ToolCall *SessionUpdateToolCall `json:"-"` - ToolCallUpdate *SessionUpdateToolCallUpdate `json:"-"` - Plan *SessionUpdatePlan `json:"-"` + UserMessageChunk *SessionUpdateUserMessageChunk `json:"-"` + AgentMessageChunk *SessionUpdateAgentMessageChunk `json:"-"` + AgentThoughtChunk *SessionUpdateAgentThoughtChunk `json:"-"` + ToolCall *SessionUpdateToolCall `json:"-"` + ToolCallUpdate *SessionUpdateToolCallUpdate `json:"-"` + Plan *SessionUpdatePlan `json:"-"` + AvailableCommandsUpdate *SessionUpdateAvailableCommandsUpdate `json:"-"` + CurrentModeUpdate *SessionUpdateCurrentModeUpdate `json:"-"` } func (u *SessionUpdate) UnmarshalJSON(b []byte) error { @@ -2024,6 +2358,20 @@ func (u *SessionUpdate) UnmarshalJSON(b []byte) error { } u.Plan = &v return nil + case "available_commands_update": + var v SessionUpdateAvailableCommandsUpdate + if json.Unmarshal(b, &v) != nil { + return errors.New("invalid variant payload") + } + u.AvailableCommandsUpdate = &v + return nil + case "current_mode_update": + var v SessionUpdateCurrentModeUpdate + if json.Unmarshal(b, &v) != nil { + return errors.New("invalid variant payload") + } + u.CurrentModeUpdate = &v + return nil } } { @@ -2131,6 +2479,40 @@ func (u *SessionUpdate) UnmarshalJSON(b []byte) error { return nil } } + { + var v SessionUpdateAvailableCommandsUpdate + var match bool = true + if _, ok := m["sessionUpdate"]; !ok { + match = false + } + if _, ok := m["availableCommands"]; !ok { + match = false + } + if match { + if json.Unmarshal(b, &v) != nil { + return errors.New("invalid variant payload") + } + u.AvailableCommandsUpdate = &v + return nil + } + } + { + var v SessionUpdateCurrentModeUpdate + var match bool = true + if _, ok := m["sessionUpdate"]; !ok { + match = false + } + if _, ok := m["currentModeId"]; !ok { + match = false + } + if match { + if json.Unmarshal(b, &v) != nil { + return errors.New("invalid variant payload") + } + u.CurrentModeUpdate = &v + return nil + } + } { var v SessionUpdateUserMessageChunk if json.Unmarshal(b, &v) == nil { @@ -2173,6 +2555,20 @@ func (u *SessionUpdate) UnmarshalJSON(b []byte) error { return nil } } + { + var v SessionUpdateAvailableCommandsUpdate + if json.Unmarshal(b, &v) == nil { + u.AvailableCommandsUpdate = &v + return nil + } + } + { + var v SessionUpdateCurrentModeUpdate + if json.Unmarshal(b, &v) == nil { + u.CurrentModeUpdate = &v + return nil + } + } return nil } func (u SessionUpdate) MarshalJSON() ([]byte, error) { @@ -2248,6 +2644,30 @@ func (u SessionUpdate) MarshalJSON() ([]byte, error) { m["sessionUpdate"] = "plan" return json.Marshal(m) } + if u.AvailableCommandsUpdate != nil { + var m map[string]any + _b, _e := json.Marshal(*u.AvailableCommandsUpdate) + if _e != nil { + return []byte{}, _e + } + if json.Unmarshal(_b, &m) != nil { + return []byte{}, errors.New("invalid variant payload") + } + m["sessionUpdate"] = "available_commands_update" + return json.Marshal(m) + } + if u.CurrentModeUpdate != nil { + var m map[string]any + _b, _e := json.Marshal(*u.CurrentModeUpdate) + if _e != nil { + return []byte{}, _e + } + if json.Unmarshal(_b, &m) != nil { + return []byte{}, errors.New("invalid variant payload") + } + m["sessionUpdate"] = "current_mode_update" + return json.Marshal(m) + } return []byte{}, nil } @@ -2271,12 +2691,36 @@ func (u *SessionUpdate) Validate() error { if u.Plan != nil { count++ } + if u.AvailableCommandsUpdate != nil { + count++ + } + if u.CurrentModeUpdate != nil { + count++ + } if count != 1 { return errors.New("SessionUpdate must have exactly one variant set") } return nil } +// **UNSTABLE** This type is not part of the spec, and may be removed or changed at any point. +type SetSessionModeRequest struct { + ModeId SessionModeId `json:"modeId"` + SessionId SessionId `json:"sessionId"` +} + +func (v *SetSessionModeRequest) Validate() error { + return nil +} + +// **UNSTABLE** This type is not part of the spec, and may be removed or changed at any point. +// SetSessionModeResponse is a union or complex schema; represented generically. +type SetSessionModeResponse any + +func (v *SetSessionModeResponse) Validate() error { + return nil +} + // Reasons why an agent stops processing a prompt turn. See protocol docs: [Stop Reasons](https://agentclientprotocol.com/protocol/prompt-turn#stop-reasons) type StopReason string @@ -2601,15 +3045,16 @@ func (t *ToolCallUpdate) Validate() error { type ToolKind string const ( - ToolKindRead ToolKind = "read" - ToolKindEdit ToolKind = "edit" - ToolKindDelete ToolKind = "delete" - ToolKindMove ToolKind = "move" - ToolKindSearch ToolKind = "search" - ToolKindExecute ToolKind = "execute" - ToolKindThink ToolKind = "think" - ToolKindFetch ToolKind = "fetch" - ToolKindOther ToolKind = "other" + ToolKindRead ToolKind = "read" + ToolKindEdit ToolKind = "edit" + ToolKindDelete ToolKind = "delete" + ToolKindMove ToolKind = "move" + ToolKindSearch ToolKind = "search" + ToolKindExecute ToolKind = "execute" + ToolKindThink ToolKind = "think" + ToolKindFetch ToolKind = "fetch" + ToolKindSwitchMode ToolKind = "switch_mode" + ToolKindOther ToolKind = "other" ) type WaitForTerminalExitRequest struct { @@ -2663,7 +3108,7 @@ type Agent interface { // AgentLoader defines optional support for loading sessions. Implement and advertise the capability to enable 'session/load'. type AgentLoader interface { - LoadSession(ctx context.Context, params LoadSessionRequest) error + LoadSession(ctx context.Context, params LoadSessionRequest) (LoadSessionResponse, error) } type Client interface { ReadTextFile(ctx context.Context, params ReadTextFileRequest) (ReadTextFileResponse, error) diff --git a/rust/markdown_generator.rs b/rust/markdown_generator.rs index 4f654c6..e776fd0 100644 --- a/rust/markdown_generator.rs +++ b/rust/markdown_generator.rs @@ -656,11 +656,9 @@ impl SideDocs { } fn extract_side_docs() -> SideDocs { - // Try to run cargo rustdoc with the current toolchain first (works with rustup via rust-toolchain.toml - // and with Nix-provided nightly toolchains). If that fails, fall back to the rustup-style '+nightly' - // invocation for environments where a default stable toolchain is active. - let mut output = Command::new("cargo") + let output = Command::new("cargo") .args([ + "+nightly", "rustdoc", "--lib", "--", @@ -673,28 +671,10 @@ fn extract_side_docs() -> SideDocs { .unwrap(); if !output.status.success() { - let fallback = Command::new("cargo") - .args([ - "+nightly", - "rustdoc", - "--lib", - "--", - "-Z", - "unstable-options", - "--output-format", - "json", - ]) - .output() - .unwrap(); - - if !fallback.status.success() { - panic!( - "Failed to generate rustdoc JSON. First attempt (no +nightly): {}\nFallback (+nightly) failed: {}", - String::from_utf8_lossy(&output.stderr), - String::from_utf8_lossy(&fallback.stderr) - ); - } - output = fallback; + panic!( + "Failed to generate rustdoc JSON: {}", + String::from_utf8_lossy(&output.stderr) + ); } // Parse the JSON output From 1549a98f86269a53ae54bee8c74e97e90bfe39df Mon Sep 17 00:00:00 2001 From: Thomas Kosiewski Date: Fri, 12 Sep 2025 10:23:25 +0200 Subject: [PATCH 21/22] feat: update LoadSession method to return LoadSessionResponse Change-Id: I8b0d3562a5abd7f61416eb4efd8f6b52c0c3f3c8 Signed-off-by: Thomas Kosiewski --- go/acp_test.go | 30 ++++++++++++------- go/cmd/generate/internal/emit/types.go | 16 +++++++--- go/example/agent/main.go | 4 ++- go/json_parity_test.go | 13 +++++++- .../json_golden/initialize_response.json | 1 + go/types_gen.go | 29 +++++++++--------- 6 files changed, 61 insertions(+), 32 deletions(-) diff --git a/go/acp_test.go b/go/acp_test.go index 96b7958..9d5984c 100644 --- a/go/acp_test.go +++ b/go/acp_test.go @@ -49,7 +49,7 @@ func (c clientFuncs) SessionUpdate(ctx context.Context, n SessionNotification) e type agentFuncs struct { InitializeFunc func(context.Context, InitializeRequest) (InitializeResponse, error) NewSessionFunc func(context.Context, NewSessionRequest) (NewSessionResponse, error) - LoadSessionFunc func(context.Context, LoadSessionRequest) error + LoadSessionFunc func(context.Context, LoadSessionRequest) (LoadSessionResponse, error) AuthenticateFunc func(context.Context, AuthenticateRequest) error PromptFunc func(context.Context, PromptRequest) (PromptResponse, error) CancelFunc func(context.Context, CancelNotification) error @@ -74,11 +74,11 @@ func (a agentFuncs) NewSession(ctx context.Context, p NewSessionRequest) (NewSes return NewSessionResponse{}, nil } -func (a agentFuncs) LoadSession(ctx context.Context, p LoadSessionRequest) error { +func (a agentFuncs) LoadSession(ctx context.Context, p LoadSessionRequest) (LoadSessionResponse, error) { if a.LoadSessionFunc != nil { return a.LoadSessionFunc(ctx, p) } - return nil + return LoadSessionResponse{}, nil } func (a agentFuncs) Authenticate(ctx context.Context, p AuthenticateRequest) error { @@ -127,8 +127,8 @@ func TestConnectionHandlesErrorsBidirectional(t *testing.T) { NewSessionFunc: func(context.Context, NewSessionRequest) (NewSessionResponse, error) { return NewSessionResponse{}, &RequestError{Code: -32603, Message: "Failed to create session"} }, - LoadSessionFunc: func(context.Context, LoadSessionRequest) error { - return &RequestError{Code: -32603, Message: "Failed to load session"} + LoadSessionFunc: func(context.Context, LoadSessionRequest) (LoadSessionResponse, error) { + return LoadSessionResponse{}, &RequestError{Code: -32603, Message: "Failed to load session"} }, AuthenticateFunc: func(context.Context, AuthenticateRequest) error { return &RequestError{Code: -32603, Message: "Authentication failed"} @@ -181,7 +181,9 @@ func TestConnectionHandlesConcurrentRequests(t *testing.T) { NewSessionFunc: func(context.Context, NewSessionRequest) (NewSessionResponse, error) { return NewSessionResponse{SessionId: "test-session"}, nil }, - LoadSessionFunc: func(context.Context, LoadSessionRequest) error { return nil }, + LoadSessionFunc: func(context.Context, LoadSessionRequest) (LoadSessionResponse, error) { + return LoadSessionResponse{}, nil + }, AuthenticateFunc: func(context.Context, AuthenticateRequest) error { return nil }, PromptFunc: func(context.Context, PromptRequest) (PromptResponse, error) { return PromptResponse{StopReason: "end_turn"}, nil @@ -254,9 +256,9 @@ func TestConnectionHandlesMessageOrdering(t *testing.T) { push("newSession called: " + p.Cwd) return NewSessionResponse{SessionId: "test-session"}, nil }, - LoadSessionFunc: func(_ context.Context, p LoadSessionRequest) error { + LoadSessionFunc: func(_ context.Context, p LoadSessionRequest) (LoadSessionResponse, error) { push("loadSession called: " + string(p.SessionId)) - return nil + return LoadSessionResponse{}, nil }, AuthenticateFunc: func(_ context.Context, p AuthenticateRequest) error { push("authenticate called: " + string(p.MethodId)) @@ -354,7 +356,9 @@ func TestConnectionHandlesNotifications(t *testing.T) { NewSessionFunc: func(context.Context, NewSessionRequest) (NewSessionResponse, error) { return NewSessionResponse{SessionId: "test-session"}, nil }, - LoadSessionFunc: func(context.Context, LoadSessionRequest) error { return nil }, + LoadSessionFunc: func(context.Context, LoadSessionRequest) (LoadSessionResponse, error) { + return LoadSessionResponse{}, nil + }, AuthenticateFunc: func(context.Context, AuthenticateRequest) error { return nil }, PromptFunc: func(context.Context, PromptRequest) (PromptResponse, error) { return PromptResponse{StopReason: "end_turn"}, nil @@ -425,7 +429,9 @@ func TestConnectionHandlesInitialize(t *testing.T) { NewSessionFunc: func(context.Context, NewSessionRequest) (NewSessionResponse, error) { return NewSessionResponse{SessionId: "test-session"}, nil }, - LoadSessionFunc: func(context.Context, LoadSessionRequest) error { return nil }, + LoadSessionFunc: func(context.Context, LoadSessionRequest) (LoadSessionResponse, error) { + return LoadSessionResponse{}, nil + }, AuthenticateFunc: func(context.Context, AuthenticateRequest) error { return nil }, PromptFunc: func(context.Context, PromptRequest) (PromptResponse, error) { return PromptResponse{StopReason: "end_turn"}, nil @@ -472,7 +478,9 @@ func TestPromptCancellationSendsCancelAndAllowsNewSession(t *testing.T) { NewSessionFunc: func(context.Context, NewSessionRequest) (NewSessionResponse, error) { return NewSessionResponse{SessionId: "s-1"}, nil }, - LoadSessionFunc: func(context.Context, LoadSessionRequest) error { return nil }, + LoadSessionFunc: func(context.Context, LoadSessionRequest) (LoadSessionResponse, error) { + return LoadSessionResponse{}, nil + }, AuthenticateFunc: func(context.Context, AuthenticateRequest) error { return nil }, PromptFunc: func(ctx context.Context, p PromptRequest) (PromptResponse, error) { <-ctx.Done() diff --git a/go/cmd/generate/internal/emit/types.go b/go/cmd/generate/internal/emit/types.go index aa7a608..80797c6 100644 --- a/go/cmd/generate/internal/emit/types.go +++ b/go/cmd/generate/internal/emit/types.go @@ -127,9 +127,10 @@ func WriteTypesJen(outDir string, schema *load.Schema, meta *load.Meta) error { defaults = append(defaults, *dp) } if _, ok := req[pk]; !ok { - // Default: omit if empty for optional fields, unless schema specifies - // a default array/object (always present on wire). - if dp == nil || (dp.kind != KindArray && dp.kind != KindObject) { + // Default: omit if empty for optional fields. + // Keep always-present behavior only for defaults where the zero value is nil (slice/map). + // For typed object defaults (non-nilable), still allow omission on the wire. + if dp == nil || (dp.kind != KindArray && dp.kind != KindObject) || (dp != nil && !dp.nilable) { tag = pk + ",omitempty" } } @@ -198,6 +199,11 @@ func WriteTypesJen(outDir string, schema *load.Schema, meta *load.Meta) error { case ir.PrimaryType(def) == "string" || ir.PrimaryType(def) == "integer" || ir.PrimaryType(def) == "number" || ir.PrimaryType(def) == "boolean": f.Type().Id(name).Add(primitiveJenType(ir.PrimaryType(def))) f.Line() + case ir.PrimaryType(def) == "object" && len(def.Properties) == 0: + // Empty object shape: emit a concrete empty struct so methods can be defined + // and the wire encoding is consistently {} rather than null. + f.Type().Id(name).Struct() + f.Line() default: f.Comment(fmt.Sprintf("%s is a union or complex schema; represented generically.", name)) f.Type().Id(name).Any() @@ -569,7 +575,9 @@ func emitUnion(f *File, name string, defs []*load.Definition, exactlyOne bool) { } } } - fieldName := tname + // Ensure Title-derived names are exported (e.g., "stdio" -> "Stdio"). + tname = util.ToExportedField(tname) + fieldName := util.ToExportedField(tname) dv := "" if discKey != "" { if pd := v.Properties[discKey]; pd != nil && pd.Const != nil { diff --git a/go/example/agent/main.go b/go/example/agent/main.go index dbb2939..906f8c2 100644 --- a/go/example/agent/main.go +++ b/go/example/agent/main.go @@ -56,7 +56,9 @@ func (a *exampleAgent) NewSession(ctx context.Context, params acp.NewSessionRequ func (a *exampleAgent) Authenticate(ctx context.Context, _ acp.AuthenticateRequest) error { return nil } -func (a *exampleAgent) LoadSession(ctx context.Context, _ acp.LoadSessionRequest) error { return nil } +func (a *exampleAgent) LoadSession(ctx context.Context, _ acp.LoadSessionRequest) (acp.LoadSessionResponse, error) { + return acp.LoadSessionResponse{}, nil +} func (a *exampleAgent) Cancel(ctx context.Context, params acp.CancelNotification) error { a.mu.Lock() diff --git a/go/json_parity_test.go b/go/json_parity_test.go index 5918a2c..13d1aa3 100644 --- a/go/json_parity_test.go +++ b/go/json_parity_test.go @@ -235,7 +235,18 @@ func TestJSONGolden_MethodPayloads(t *testing.T) { return InitializeResponse{ProtocolVersion: 1, AgentCapabilities: AgentCapabilities{LoadSession: true, PromptCapabilities: PromptCapabilities{Image: true, Audio: true, EmbeddedContext: true}}} })) t.Run("new_session_request", runGolden(func() NewSessionRequest { - return NewSessionRequest{Cwd: "/home/user/project", McpServers: []McpServer{{Name: "filesystem", Command: "/path/to/mcp-server", Args: []string{"--stdio"}, Env: []EnvVariable{}}}} + return NewSessionRequest{ + Cwd: "/home/user/project", McpServers: []McpServer{ + { + Stdio: &Stdio{ + Name: "filesystem", + Command: "/path/to/mcp-server", + Args: []string{"--stdio"}, + Env: []EnvVariable{}, + }, + }, + }, + } })) t.Run("new_session_response", runGolden(func() NewSessionResponse { return NewSessionResponse{SessionId: "sess_abc123def456"} })) t.Run("prompt_request", runGolden(func() PromptRequest { diff --git a/go/testdata/json_golden/initialize_response.json b/go/testdata/json_golden/initialize_response.json index 6524b96..66abb81 100644 --- a/go/testdata/json_golden/initialize_response.json +++ b/go/testdata/json_golden/initialize_response.json @@ -2,6 +2,7 @@ "protocolVersion": 1, "agentCapabilities": { "loadSession": true, + "mcpCapabilities": {}, "promptCapabilities": { "image": true, "audio": true, diff --git a/go/types_gen.go b/go/types_gen.go index 9abbbdd..a3c7346 100644 --- a/go/types_gen.go +++ b/go/types_gen.go @@ -18,11 +18,11 @@ type AgentCapabilities struct { // MCP capabilities supported by the agent. // // Defaults to {"http":false,"sse":false} if unset. - McpCapabilities McpCapabilities `json:"mcpCapabilities"` + McpCapabilities McpCapabilities `json:"mcpCapabilities,omitempty"` // Prompt capabilities supported by the agent. // // Defaults to {"audio":false,"embeddedContext":false,"image":false} if unset. - PromptCapabilities PromptCapabilities `json:"promptCapabilities"` + PromptCapabilities PromptCapabilities `json:"promptCapabilities,omitempty"` } func (v AgentCapabilities) MarshalJSON() ([]byte, error) { @@ -515,7 +515,7 @@ type ClientCapabilities struct { // File system capabilities supported by the client. Determines which file operations the agent can request. // // Defaults to {"readTextFile":false,"writeTextFile":false} if unset. - Fs FileSystemCapability `json:"fs"` + Fs FileSystemCapability `json:"fs,omitempty"` // **UNSTABLE** This capability is not part of the spec yet, and may be removed or changed at any point. // // Defaults to false if unset. @@ -1414,7 +1414,7 @@ type InitializeRequest struct { // Capabilities supported by the client. // // Defaults to {"fs":{"readTextFile":false,"writeTextFile":false},"terminal":false} if unset. - ClientCapabilities ClientCapabilities `json:"clientCapabilities"` + ClientCapabilities ClientCapabilities `json:"clientCapabilities,omitempty"` // The latest protocol version supported by the client. ProtocolVersion ProtocolVersion `json:"protocolVersion"` } @@ -1455,7 +1455,7 @@ type InitializeResponse struct { // Capabilities supported by the agent. // // Defaults to {"loadSession":false,"mcpCapabilities":{"http":false,"sse":false},"promptCapabilities":{"audio":false,"embeddedContext":false,"image":false}} if unset. - AgentCapabilities AgentCapabilities `json:"agentCapabilities"` + AgentCapabilities AgentCapabilities `json:"agentCapabilities,omitempty"` // Authentication methods supported by the agent. // // Defaults to [] if unset. @@ -1615,7 +1615,7 @@ type McpServerSse struct { } // Stdio transport configuration All Agents MUST support this transport. -type stdio struct { +type Stdio struct { // Command-line arguments to pass to the MCP server. Args []string `json:"args"` // Path to the MCP server executable. @@ -1629,7 +1629,7 @@ type stdio struct { type McpServer struct { Http *McpServerHttp `json:"-"` Sse *McpServerSse `json:"-"` - stdio *stdio `json:"-"` + Stdio *Stdio `json:"-"` } func (u *McpServer) UnmarshalJSON(b []byte) error { @@ -1706,7 +1706,7 @@ func (u *McpServer) UnmarshalJSON(b []byte) error { } } { - var v stdio + var v Stdio var match bool = true if _, ok := m["name"]; !ok { match = false @@ -1724,7 +1724,7 @@ func (u *McpServer) UnmarshalJSON(b []byte) error { if json.Unmarshal(b, &v) != nil { return errors.New("invalid variant payload") } - u.stdio = &v + u.Stdio = &v return nil } } @@ -1743,9 +1743,9 @@ func (u *McpServer) UnmarshalJSON(b []byte) error { } } { - var v stdio + var v Stdio if json.Unmarshal(b, &v) == nil { - u.stdio = &v + u.Stdio = &v return nil } } @@ -1776,9 +1776,9 @@ func (u McpServer) MarshalJSON() ([]byte, error) { m["type"] = "sse" return json.Marshal(m) } - if u.stdio != nil { + if u.Stdio != nil { var m map[string]any - _b, _e := json.Marshal(*u.stdio) + _b, _e := json.Marshal(*u.Stdio) if _e != nil { return []byte{}, _e } @@ -2714,8 +2714,7 @@ func (v *SetSessionModeRequest) Validate() error { } // **UNSTABLE** This type is not part of the spec, and may be removed or changed at any point. -// SetSessionModeResponse is a union or complex schema; represented generically. -type SetSessionModeResponse any +type SetSessionModeResponse struct{} func (v *SetSessionModeResponse) Validate() error { return nil From f952023e89695dbcc35543664ed47ea1c57a7501 Mon Sep 17 00:00:00 2001 From: Thomas Kosiewski Date: Fri, 12 Sep 2025 10:47:08 +0200 Subject: [PATCH 22/22] feat: consolidate terminal methods into main Client interface and add KillTerminalCommand Change-Id: Ie34241477e4f04a55a57dd04639c68f9b2b6fcc3 Signed-off-by: Thomas Kosiewski --- go/acp_test.go | 58 +++++++++++++++++--- go/agent_gen.go | 3 ++ go/client_gen.go | 36 ++++++------- go/example/claude-code/main.go | 6 +++ go/example/client/main.go | 11 ++-- go/example/gemini/main.go | 11 ++-- go/example_client_test.go | 23 ++++++++ go/example_gemini_test.go | 17 ++++++ go/types_gen.go | 97 ++++++++++++++++++++++------------ 9 files changed, 194 insertions(+), 68 deletions(-) diff --git a/go/acp_test.go b/go/acp_test.go index 9d5984c..b180a7e 100644 --- a/go/acp_test.go +++ b/go/acp_test.go @@ -14,6 +14,12 @@ type clientFuncs struct { ReadTextFileFunc func(context.Context, ReadTextFileRequest) (ReadTextFileResponse, error) RequestPermissionFunc func(context.Context, RequestPermissionRequest) (RequestPermissionResponse, error) SessionUpdateFunc func(context.Context, SessionNotification) error + // Terminal-related handlers + CreateTerminalFunc func(context.Context, CreateTerminalRequest) (CreateTerminalResponse, error) + KillTerminalCommandFunc func(context.Context, KillTerminalCommandRequest) error + ReleaseTerminalFunc func(context.Context, ReleaseTerminalRequest) error + TerminalOutputFunc func(context.Context, TerminalOutputRequest) (TerminalOutputResponse, error) + WaitForTerminalExitFunc func(context.Context, WaitForTerminalExitRequest) (WaitForTerminalExitResponse, error) } var _ Client = (*clientFuncs)(nil) @@ -46,6 +52,46 @@ func (c clientFuncs) SessionUpdate(ctx context.Context, n SessionNotification) e return nil } +// CreateTerminal implements Client. +func (c *clientFuncs) CreateTerminal(ctx context.Context, params CreateTerminalRequest) (CreateTerminalResponse, error) { + if c.CreateTerminalFunc != nil { + return c.CreateTerminalFunc(ctx, params) + } + return CreateTerminalResponse{TerminalId: "test-terminal"}, nil +} + +// KillTerminalCommand implements Client. +func (c *clientFuncs) KillTerminalCommand(ctx context.Context, params KillTerminalCommandRequest) error { + if c.KillTerminalCommandFunc != nil { + return c.KillTerminalCommandFunc(ctx, params) + } + return nil +} + +// ReleaseTerminal implements Client. +func (c *clientFuncs) ReleaseTerminal(ctx context.Context, params ReleaseTerminalRequest) error { + if c.ReleaseTerminalFunc != nil { + return c.ReleaseTerminalFunc(ctx, params) + } + return nil +} + +// TerminalOutput implements Client. +func (c *clientFuncs) TerminalOutput(ctx context.Context, params TerminalOutputRequest) (TerminalOutputResponse, error) { + if c.TerminalOutputFunc != nil { + return c.TerminalOutputFunc(ctx, params) + } + return TerminalOutputResponse{Output: "ok", Truncated: false}, nil +} + +// WaitForTerminalExit implements Client. +func (c *clientFuncs) WaitForTerminalExit(ctx context.Context, params WaitForTerminalExitRequest) (WaitForTerminalExitResponse, error) { + if c.WaitForTerminalExitFunc != nil { + return c.WaitForTerminalExitFunc(ctx, params) + } + return WaitForTerminalExitResponse{}, nil +} + type agentFuncs struct { InitializeFunc func(context.Context, InitializeRequest) (InitializeResponse, error) NewSessionFunc func(context.Context, NewSessionRequest) (NewSessionResponse, error) @@ -108,7 +154,7 @@ func TestConnectionHandlesErrorsBidirectional(t *testing.T) { c2aR, c2aW := io.Pipe() a2cR, a2cW := io.Pipe() - c := NewClientSideConnection(clientFuncs{ + c := NewClientSideConnection(&clientFuncs{ WriteTextFileFunc: func(context.Context, WriteTextFileRequest) error { return &RequestError{Code: -32603, Message: "Write failed"} }, @@ -158,7 +204,7 @@ func TestConnectionHandlesConcurrentRequests(t *testing.T) { var mu sync.Mutex requestCount := 0 - _ = NewClientSideConnection(clientFuncs{ + _ = NewClientSideConnection(&clientFuncs{ WriteTextFileFunc: func(context.Context, WriteTextFileRequest) error { mu.Lock() requestCount++ @@ -229,7 +275,7 @@ func TestConnectionHandlesMessageOrdering(t *testing.T) { var log []string push := func(s string) { mu.Lock(); defer mu.Unlock(); log = append(log, s) } - cs := NewClientSideConnection(clientFuncs{ + cs := NewClientSideConnection(&clientFuncs{ WriteTextFileFunc: func(_ context.Context, req WriteTextFileRequest) error { push("writeTextFile called: " + req.Path) return nil @@ -329,7 +375,7 @@ func TestConnectionHandlesNotifications(t *testing.T) { var logs []string push := func(s string) { mu.Lock(); logs = append(logs, s); mu.Unlock() } - clientSide := NewClientSideConnection(clientFuncs{ + clientSide := NewClientSideConnection(&clientFuncs{ WriteTextFileFunc: func(context.Context, WriteTextFileRequest) error { return nil }, ReadTextFileFunc: func(context.Context, ReadTextFileRequest) (ReadTextFileResponse, error) { return ReadTextFileResponse{Content: "test"}, nil @@ -400,7 +446,7 @@ func TestConnectionHandlesInitialize(t *testing.T) { c2aR, c2aW := io.Pipe() a2cR, a2cW := io.Pipe() - agentConn := NewClientSideConnection(clientFuncs{ + agentConn := NewClientSideConnection(&clientFuncs{ WriteTextFileFunc: func(context.Context, WriteTextFileRequest) error { return nil }, ReadTextFileFunc: func(context.Context, ReadTextFileRequest) (ReadTextFileResponse, error) { return ReadTextFileResponse{Content: "test"}, nil @@ -501,7 +547,7 @@ func TestPromptCancellationSendsCancelAndAllowsNewSession(t *testing.T) { }, a2cW, c2aR) // Client side - cs := NewClientSideConnection(clientFuncs{ + cs := NewClientSideConnection(&clientFuncs{ WriteTextFileFunc: func(context.Context, WriteTextFileRequest) error { return nil }, ReadTextFileFunc: func(context.Context, ReadTextFileRequest) (ReadTextFileResponse, error) { return ReadTextFileResponse{Content: ""}, nil diff --git a/go/agent_gen.go b/go/agent_gen.go index 883fb21..d4abcd4 100644 --- a/go/agent_gen.go +++ b/go/agent_gen.go @@ -129,6 +129,9 @@ func (c *AgentSideConnection) CreateTerminal(ctx context.Context, params CreateT resp, err := SendRequest[CreateTerminalResponse](c.conn, ctx, ClientMethodTerminalCreate, params) return resp, err } +func (c *AgentSideConnection) KillTerminalCommand(ctx context.Context, params KillTerminalCommandRequest) error { + return c.conn.SendRequestNoResult(ctx, ClientMethodTerminalKill, params) +} func (c *AgentSideConnection) TerminalOutput(ctx context.Context, params TerminalOutputRequest) (TerminalOutputResponse, error) { resp, err := SendRequest[TerminalOutputResponse](c.conn, ctx, ClientMethodTerminalOutput, params) return resp, err diff --git a/go/client_gen.go b/go/client_gen.go index ced57c4..fdb4c01 100644 --- a/go/client_gen.go +++ b/go/client_gen.go @@ -67,15 +67,23 @@ func (c *ClientSideConnection) handle(ctx context.Context, method string, params if err := p.Validate(); err != nil { return nil, NewInvalidParams(map[string]any{"error": err.Error()}) } - t, ok := c.client.(ClientTerminal) - if !ok { - return nil, NewMethodNotFound(method) - } - resp, err := t.CreateTerminal(ctx, p) + resp, err := c.client.CreateTerminal(ctx, p) if err != nil { return nil, toReqErr(err) } return resp, nil + case ClientMethodTerminalKill: + var p KillTerminalCommandRequest + if err := json.Unmarshal(params, &p); err != nil { + return nil, NewInvalidParams(map[string]any{"error": err.Error()}) + } + if err := p.Validate(); err != nil { + return nil, NewInvalidParams(map[string]any{"error": err.Error()}) + } + if err := c.client.KillTerminalCommand(ctx, p); err != nil { + return nil, toReqErr(err) + } + return nil, nil case ClientMethodTerminalOutput: var p TerminalOutputRequest if err := json.Unmarshal(params, &p); err != nil { @@ -84,11 +92,7 @@ func (c *ClientSideConnection) handle(ctx context.Context, method string, params if err := p.Validate(); err != nil { return nil, NewInvalidParams(map[string]any{"error": err.Error()}) } - t, ok := c.client.(ClientTerminal) - if !ok { - return nil, NewMethodNotFound(method) - } - resp, err := t.TerminalOutput(ctx, p) + resp, err := c.client.TerminalOutput(ctx, p) if err != nil { return nil, toReqErr(err) } @@ -101,11 +105,7 @@ func (c *ClientSideConnection) handle(ctx context.Context, method string, params if err := p.Validate(); err != nil { return nil, NewInvalidParams(map[string]any{"error": err.Error()}) } - t, ok := c.client.(ClientTerminal) - if !ok { - return nil, NewMethodNotFound(method) - } - if err := t.ReleaseTerminal(ctx, p); err != nil { + if err := c.client.ReleaseTerminal(ctx, p); err != nil { return nil, toReqErr(err) } return nil, nil @@ -117,11 +117,7 @@ func (c *ClientSideConnection) handle(ctx context.Context, method string, params if err := p.Validate(); err != nil { return nil, NewInvalidParams(map[string]any{"error": err.Error()}) } - t, ok := c.client.(ClientTerminal) - if !ok { - return nil, NewMethodNotFound(method) - } - resp, err := t.WaitForTerminalExit(ctx, p) + resp, err := c.client.WaitForTerminalExit(ctx, p) if err != nil { return nil, toReqErr(err) } diff --git a/go/example/claude-code/main.go b/go/example/claude-code/main.go index fcdbbda..aefd9f8 100644 --- a/go/example/claude-code/main.go +++ b/go/example/claude-code/main.go @@ -154,6 +154,12 @@ func (c *replClient) WaitForTerminalExit(ctx context.Context, params acp.WaitFor return acp.WaitForTerminalExitResponse{}, nil } +// KillTerminalCommand implements acp.Client. +func (c *replClient) KillTerminalCommand(ctx context.Context, params acp.KillTerminalCommandRequest) error { + fmt.Printf("[Client] KillTerminalCommand: %v\n", params) + return nil +} + func main() { yolo := flag.Bool("yolo", false, "Auto-approve permission prompts") flag.Parse() diff --git a/go/example/client/main.go b/go/example/client/main.go index be85a4d..0b8a987 100644 --- a/go/example/client/main.go +++ b/go/example/client/main.go @@ -16,10 +16,7 @@ import ( type exampleClient struct{} -var ( - _ acp.Client = (*exampleClient)(nil) - _ acp.ClientTerminal = (*exampleClient)(nil) -) +var _ acp.Client = (*exampleClient)(nil) func (e *exampleClient) RequestPermission(ctx context.Context, params acp.RequestPermissionRequest) (acp.RequestPermissionResponse, error) { title := "" @@ -153,6 +150,12 @@ func (e *exampleClient) WaitForTerminalExit(ctx context.Context, params acp.Wait return acp.WaitForTerminalExitResponse{}, nil } +// KillTerminalCommand implements acp.Client. +func (c *exampleClient) KillTerminalCommand(ctx context.Context, params acp.KillTerminalCommandRequest) error { + fmt.Printf("[Client] KillTerminalCommand: %v\n", params) + return nil +} + func main() { // If args provided, treat them as agent program + args. Otherwise run the Go agent example. ctx, cancel := context.WithCancel(context.Background()) diff --git a/go/example/gemini/main.go b/go/example/gemini/main.go index 694dea1..868263a 100644 --- a/go/example/gemini/main.go +++ b/go/example/gemini/main.go @@ -21,10 +21,7 @@ type replClient struct { autoApprove bool } -var ( - _ acp.Client = (*replClient)(nil) - _ acp.ClientTerminal = (*replClient)(nil) -) +var _ acp.Client = (*replClient)(nil) func (c *replClient) RequestPermission(ctx context.Context, params acp.RequestPermissionRequest) (acp.RequestPermissionResponse, error) { if c.autoApprove { @@ -157,6 +154,12 @@ func (c *replClient) WaitForTerminalExit(ctx context.Context, params acp.WaitFor return acp.WaitForTerminalExitResponse{}, nil } +// KillTerminalCommand implements acp.Client. +func (c *replClient) KillTerminalCommand(ctx context.Context, params acp.KillTerminalCommandRequest) error { + fmt.Printf("[Client] KillTerminalCommand: %v\n", params) + return nil +} + func main() { binary := flag.String("gemini", "gemini", "Path to the Gemini CLI binary") model := flag.String("model", "", "Model to pass to Gemini (optional)") diff --git a/go/example_client_test.go b/go/example_client_test.go index e90e70c..438c11f 100644 --- a/go/example_client_test.go +++ b/go/example_client_test.go @@ -85,6 +85,29 @@ func (clientExample) ReadTextFile(ctx context.Context, p ReadTextFileRequest) (R return ReadTextFileResponse{Content: content}, nil } +// Terminal interface implementations (minimal stubs for examples) +func (clientExample) CreateTerminal(ctx context.Context, p CreateTerminalRequest) (CreateTerminalResponse, error) { + // Return a dummy terminal id + return CreateTerminalResponse{TerminalId: "t-1"}, nil +} + +func (clientExample) KillTerminalCommand(ctx context.Context, p KillTerminalCommandRequest) error { + return nil +} + +func (clientExample) ReleaseTerminal(ctx context.Context, p ReleaseTerminalRequest) error { + return nil +} + +func (clientExample) TerminalOutput(ctx context.Context, p TerminalOutputRequest) (TerminalOutputResponse, error) { + // Provide non-empty output to satisfy validation + return TerminalOutputResponse{Output: "ok", Truncated: false}, nil +} + +func (clientExample) WaitForTerminalExit(ctx context.Context, p WaitForTerminalExitRequest) (WaitForTerminalExitResponse, error) { + return WaitForTerminalExitResponse{}, nil +} + // Example_client launches the Go agent example, negotiates protocol, // opens a session, and sends a simple prompt. func Example_client() { diff --git a/go/example_gemini_test.go b/go/example_gemini_test.go index 97f4ac4..c891a29 100644 --- a/go/example_gemini_test.go +++ b/go/example_gemini_test.go @@ -32,6 +32,23 @@ func (geminiClient) ReadTextFile(ctx context.Context, _ ReadTextFileRequest) (Re } func (geminiClient) WriteTextFile(ctx context.Context, _ WriteTextFileRequest) error { return nil } +// Terminal interface implementations (minimal stubs for examples) +func (geminiClient) CreateTerminal(ctx context.Context, p CreateTerminalRequest) (CreateTerminalResponse, error) { + return CreateTerminalResponse{TerminalId: "t-1"}, nil +} + +func (geminiClient) KillTerminalCommand(ctx context.Context, p KillTerminalCommandRequest) error { + return nil +} +func (geminiClient) ReleaseTerminal(ctx context.Context, p ReleaseTerminalRequest) error { return nil } +func (geminiClient) TerminalOutput(ctx context.Context, p TerminalOutputRequest) (TerminalOutputResponse, error) { + return TerminalOutputResponse{Output: "ok", Truncated: false}, nil +} + +func (geminiClient) WaitForTerminalExit(ctx context.Context, p WaitForTerminalExitRequest) (WaitForTerminalExitResponse, error) { + return WaitForTerminalExitResponse{}, nil +} + // Example_gemini connects to a Gemini CLI speaking ACP over stdio, // then initializes, opens a session, and sends a prompt. func Example_gemini() { diff --git a/go/types_gen.go b/go/types_gen.go index a3c7346..f8d5c04 100644 --- a/go/types_gen.go +++ b/go/types_gen.go @@ -107,7 +107,7 @@ type AgentRequest struct { TerminalOutputRequest *TerminalOutputRequest `json:"-"` ReleaseTerminalRequest *ReleaseTerminalRequest `json:"-"` WaitForTerminalExitRequest *WaitForTerminalExitRequest `json:"-"` - KillTerminalRequest *KillTerminalRequest `json:"-"` + KillTerminalCommandRequest *KillTerminalCommandRequest `json:"-"` } func (u *AgentRequest) UnmarshalJSON(b []byte) error { @@ -165,9 +165,9 @@ func (u *AgentRequest) UnmarshalJSON(b []byte) error { } } { - var v KillTerminalRequest + var v KillTerminalCommandRequest if json.Unmarshal(b, &v) == nil { - u.KillTerminalRequest = &v + u.KillTerminalCommandRequest = &v return nil } } @@ -251,9 +251,9 @@ func (u AgentRequest) MarshalJSON() ([]byte, error) { } return json.Marshal(m) } - if u.KillTerminalRequest != nil { + if u.KillTerminalCommandRequest != nil { var m map[string]any - _b, _e := json.Marshal(*u.KillTerminalRequest) + _b, _e := json.Marshal(*u.KillTerminalCommandRequest) if _e != nil { return []byte{}, _e } @@ -516,7 +516,7 @@ type ClientCapabilities struct { // // Defaults to {"readTextFile":false,"writeTextFile":false} if unset. Fs FileSystemCapability `json:"fs,omitempty"` - // **UNSTABLE** This capability is not part of the spec yet, and may be removed or changed at any point. + // Whether the Client support all 'terminal/*' methods. // // Defaults to false if unset. Terminal bool `json:"terminal,omitempty"` @@ -1238,13 +1238,20 @@ func (u *ContentBlock) Validate() error { return nil } +// Request to create a new terminal and execute a command. type CreateTerminalRequest struct { - Args []string `json:"args,omitempty"` - Command string `json:"command"` - Cwd *string `json:"cwd,omitempty"` - Env []EnvVariable `json:"env,omitempty"` - OutputByteLimit *int `json:"outputByteLimit,omitempty"` - SessionId SessionId `json:"sessionId"` + // Array of command arguments. + Args []string `json:"args,omitempty"` + // The command to execute. + Command string `json:"command"` + // Working directory for the command (absolute path). + Cwd *string `json:"cwd,omitempty"` + // Environment variables for the command. + Env []EnvVariable `json:"env,omitempty"` + // Maximum number of output bytes to retain. When the limit is exceeded, the Client truncates from the beginning of the output to stay within the limit. The Client MUST ensure truncation happens at a character boundary to maintain valid string output, even if this means the retained output is slightly less than the specified limit. + OutputByteLimit *int `json:"outputByteLimit,omitempty"` + // The session ID for this request. + SessionId SessionId `json:"sessionId"` } func (v *CreateTerminalRequest) Validate() error { @@ -1254,7 +1261,9 @@ func (v *CreateTerminalRequest) Validate() error { return nil } +// Response containing the ID of the created terminal. type CreateTerminalResponse struct { + // The unique identifier for the created terminal. TerminalId string `json:"terminalId"` } @@ -1504,12 +1513,15 @@ func (v *InitializeResponse) Validate() error { return nil } -type KillTerminalRequest struct { - SessionId SessionId `json:"sessionId"` - TerminalId string `json:"terminalId"` +// Request to kill a terminal command without releasing the terminal. +type KillTerminalCommandRequest struct { + // The session ID for this request. + SessionId SessionId `json:"sessionId"` + // The ID of the terminal to kill. + TerminalId string `json:"terminalId"` } -func (v *KillTerminalRequest) Validate() error { +func (v *KillTerminalCommandRequest) Validate() error { if v.TerminalId == "" { return fmt.Errorf("terminalId is required") } @@ -1962,9 +1974,9 @@ type ProtocolVersion int // Request to read content from a text file. Only available if the client supports the 'fs.readTextFile' capability. type ReadTextFileRequest struct { - // Optional maximum number of lines to read. + // Maximum number of lines to read. Limit *int `json:"limit,omitempty"` - // Optional line number to start reading from (1-based). + // Line number to start reading from (1-based). Line *int `json:"line,omitempty"` // Absolute path to the file to read. Path string `json:"path"` @@ -1991,9 +2003,12 @@ func (v *ReadTextFileResponse) Validate() error { return nil } +// Request to release a terminal and free its resources. type ReleaseTerminalRequest struct { - SessionId SessionId `json:"sessionId"` - TerminalId string `json:"terminalId"` + // The session ID for this request. + SessionId SessionId `json:"sessionId"` + // The ID of the terminal to release. + TerminalId string `json:"terminalId"` } func (v *ReleaseTerminalRequest) Validate() error { @@ -2731,14 +2746,20 @@ const ( StopReasonCancelled StopReason = "cancelled" ) +// Exit status of a terminal command. type TerminalExitStatus struct { - ExitCode *int `json:"exitCode,omitempty"` - Signal *string `json:"signal,omitempty"` + // The process exit code (may be null if terminated by signal). + ExitCode *int `json:"exitCode,omitempty"` + // The signal that terminated the process (may be null if exited normally). + Signal *string `json:"signal,omitempty"` } +// Request to get the current output and status of a terminal. type TerminalOutputRequest struct { - SessionId SessionId `json:"sessionId"` - TerminalId string `json:"terminalId"` + // The session ID for this request. + SessionId SessionId `json:"sessionId"` + // The ID of the terminal to get output from. + TerminalId string `json:"terminalId"` } func (v *TerminalOutputRequest) Validate() error { @@ -2748,10 +2769,14 @@ func (v *TerminalOutputRequest) Validate() error { return nil } +// Response containing the terminal output and exit status. type TerminalOutputResponse struct { + // Exit status if the command has completed. ExitStatus *TerminalExitStatus `json:"exitStatus,omitempty"` - Output string `json:"output"` - Truncated bool `json:"truncated"` + // The terminal output captured so far. + Output string `json:"output"` + // Whether the output was truncated due to byte limits. + Truncated bool `json:"truncated"` } func (v *TerminalOutputResponse) Validate() error { @@ -2813,6 +2838,7 @@ type ToolCallContentDiff struct { Type string `json:"type"` } +// Embed a terminal created with 'terminal/create' by its id. The terminal must be added before calling 'terminal/release'. See protocol docs: [Terminal](https://agentclientprotocol.com/protocol/terminal) type ToolCallContentTerminal struct { TerminalId string `json:"terminalId"` Type string `json:"type"` @@ -3056,9 +3082,12 @@ const ( ToolKindOther ToolKind = "other" ) +// Request to wait for a terminal command to exit. type WaitForTerminalExitRequest struct { - SessionId SessionId `json:"sessionId"` - TerminalId string `json:"terminalId"` + // The session ID for this request. + SessionId SessionId `json:"sessionId"` + // The ID of the terminal to wait for. + TerminalId string `json:"terminalId"` } func (v *WaitForTerminalExitRequest) Validate() error { @@ -3068,9 +3097,12 @@ func (v *WaitForTerminalExitRequest) Validate() error { return nil } +// Response containing the exit status of a terminal command. type WaitForTerminalExitResponse struct { - ExitCode *int `json:"exitCode,omitempty"` - Signal *string `json:"signal,omitempty"` + // The process exit code (may be null if terminated by signal). + ExitCode *int `json:"exitCode,omitempty"` + // The signal that terminated the process (may be null if exited normally). + Signal *string `json:"signal,omitempty"` } func (v *WaitForTerminalExitResponse) Validate() error { @@ -3114,11 +3146,8 @@ type Client interface { WriteTextFile(ctx context.Context, params WriteTextFileRequest) error RequestPermission(ctx context.Context, params RequestPermissionRequest) (RequestPermissionResponse, error) SessionUpdate(ctx context.Context, params SessionNotification) error -} - -// ClientTerminal defines terminal-related experimental methods (x-docs-ignore). Implement and advertise 'terminal: true' to enable 'terminal/*'. -type ClientTerminal interface { CreateTerminal(ctx context.Context, params CreateTerminalRequest) (CreateTerminalResponse, error) + KillTerminalCommand(ctx context.Context, params KillTerminalCommandRequest) error TerminalOutput(ctx context.Context, params TerminalOutputRequest) (TerminalOutputResponse, error) ReleaseTerminal(ctx context.Context, params ReleaseTerminalRequest) error WaitForTerminalExit(ctx context.Context, params WaitForTerminalExitRequest) (WaitForTerminalExitResponse, error)