From 6a3ce66e091e718c5261ee77c0151657b10c37ec Mon Sep 17 00:00:00 2001 From: tomholford <128995+tomholford@noreply.gitea.com> Date: Mon, 23 Mar 2026 11:53:00 +0000 Subject: [PATCH] feat(pull): add `draft` parameter for creating/updating draft PRs (#159) ## Summary The Gitea API has no native `draft` field on `CreatePullRequestOption` or `EditPullRequestOption`. Instead, Gitea treats PRs whose title starts with a WIP prefix (e.g. `WIP:`, `[WIP]`) as drafts. This adds a `draft` boolean parameter to the `pull_request_write` tool so MCP clients can create/update draft PRs without knowing about the WIP prefix convention. ## Changes - Add `draft` boolean parameter to `PullRequestWriteTool` schema, supported on `create` and `update` methods - Add `applyDraftPrefix()` helper that handles both default Gitea WIP prefixes (`WIP:`, `[WIP]`) case-insensitively - When `draft=true` and no prefix exists, prepend `WIP: `; when a prefix already exists, preserve the title as-is (no normalization) - When `draft=false`, strip any recognized WIP prefix - On `update`, if `draft` is set without `title`, auto-fetch the current PR title via GET - Add tests: 12 unit tests for `applyDraftPrefix`, 5 integration tests for create, 4 for edit --------- Co-authored-by: tomholford Reviewed-on: https://gitea.com/gitea/gitea-mcp/pulls/159 Reviewed-by: silverwind <2021+silverwind@noreply.gitea.com> Co-authored-by: tomholford <128995+tomholford@noreply.gitea.com> Co-committed-by: tomholford <128995+tomholford@noreply.gitea.com> --- operation/pull/pull.go | 45 +++++++ operation/pull/pull_test.go | 230 ++++++++++++++++++++++++++++++++++++ 2 files changed, 275 insertions(+) diff --git a/operation/pull/pull.go b/operation/pull/pull.go index 9ceedca..1c52022 100644 --- a/operation/pull/pull.go +++ b/operation/pull/pull.go @@ -3,6 +3,7 @@ package pull import ( "context" "fmt" + "strings" "gitea.com/gitea/gitea-mcp/pkg/gitea" "gitea.com/gitea/gitea-mcp/pkg/log" @@ -71,6 +72,7 @@ var ( mcp.WithBoolean("delete_branch", mcp.Description("delete branch after merge (for 'merge')")), mcp.WithArray("reviewers", mcp.Description("reviewer usernames (for 'add_reviewers', 'remove_reviewers')"), mcp.Items(map[string]any{"type": "string"})), mcp.WithArray("team_reviewers", mcp.Description("team reviewer names (for 'add_reviewers', 'remove_reviewers')"), mcp.Items(map[string]any{"type": "string"})), + mcp.WithBoolean("draft", mcp.Description("mark PR as draft (for 'create', 'update'). Gitea uses a 'WIP: ' title prefix for drafts.")), ) PullRequestReviewWriteTool = mcp.NewTool( @@ -271,6 +273,28 @@ func listRepoPullRequestsFn(ctx context.Context, req mcp.CallToolRequest) (*mcp. return to.TextResult(slimPullRequests(pullRequests)) } +// defaultWIPPrefixes are the default Gitea title prefixes that mark a PR as +// work-in-progress / draft. Gitea matches these case-insensitively. +var defaultWIPPrefixes = []string{"WIP:", "[WIP]"} + +// applyDraftPrefix adds or removes a WIP title prefix that Gitea uses to mark +// pull requests as drafts. When the title already carries a recognized prefix +// and isDraft is true, the title is returned unchanged to avoid normalization. +func applyDraftPrefix(title string, isDraft bool) string { + for _, prefix := range defaultWIPPrefixes { + if len(title) >= len(prefix) && strings.EqualFold(title[:len(prefix)], prefix) { + if isDraft { + return title + } + return strings.TrimLeft(title[len(prefix):], " ") + } + } + if isDraft { + return "WIP: " + title + } + return title +} + func createPullRequestFn(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { log.Debugf("Called createPullRequestFn") args := req.GetArguments() @@ -298,6 +322,11 @@ func createPullRequestFn(ctx context.Context, req mcp.CallToolRequest) (*mcp.Cal if err != nil { return to.ErrorResult(err) } + + if draft, ok := args["draft"].(bool); ok { + title = applyDraftPrefix(title, draft) + } + client, err := gitea.ClientFromContext(ctx) if err != nil { return to.ErrorResult(fmt.Errorf("get gitea client err: %v", err)) @@ -774,6 +803,22 @@ func editPullRequestFn(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallT if title, ok := args["title"].(string); ok { opt.Title = title } + if draft, ok := args["draft"].(bool); ok { + if opt.Title == "" { + // Fetch current title so the caller doesn't have to provide it + // just to toggle draft status. + client, err := gitea.ClientFromContext(ctx) + if err != nil { + return to.ErrorResult(fmt.Errorf("get gitea client err: %v", err)) + } + pr, _, err := client.GetPullRequest(owner, repo, index) + if err != nil { + return to.ErrorResult(fmt.Errorf("get %v/%v/pr/%v err: %v", owner, repo, index, err)) + } + opt.Title = pr.Title + } + opt.Title = applyDraftPrefix(opt.Title, draft) + } if body, ok := args["body"].(string); ok { opt.Body = new(body) } diff --git a/operation/pull/pull_test.go b/operation/pull/pull_test.go index cdcb081..56fb48d 100644 --- a/operation/pull/pull_test.go +++ b/operation/pull/pull_test.go @@ -254,6 +254,236 @@ func Test_mergePullRequestFn(t *testing.T) { } } +func Test_applyDraftPrefix(t *testing.T) { + tests := []struct { + name string + title string + isDraft bool + want string + }{ + {"add prefix", "my feature", true, "WIP: my feature"}, + {"already prefixed WIP:", "WIP: my feature", true, "WIP: my feature"}, + {"already prefixed WIP: no space", "WIP:my feature", true, "WIP:my feature"}, + {"already prefixed [WIP]", "[WIP] my feature", true, "[WIP] my feature"}, + {"already prefixed case insensitive", "wip: my feature", true, "wip: my feature"}, + {"already prefixed [wip]", "[wip] my feature", true, "[wip] my feature"}, + {"remove WIP: prefix", "WIP: my feature", false, "my feature"}, + {"remove WIP: no space", "WIP:my feature", false, "my feature"}, + {"remove [WIP] prefix", "[WIP] my feature", false, "my feature"}, + {"remove [wip] prefix", "[wip] my feature", false, "my feature"}, + {"remove wip: lowercase", "wip: my feature", false, "my feature"}, + {"no prefix not draft", "my feature", false, "my feature"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := applyDraftPrefix(tt.title, tt.isDraft) + if got != tt.want { + t.Fatalf("applyDraftPrefix(%q, %v) = %q, want %q", tt.title, tt.isDraft, got, tt.want) + } + }) + } +} + +func Test_createPullRequestFn_draft(t *testing.T) { + const ( + owner = "octo" + repo = "demo" + ) + + tests := []struct { + name string + title string + draft any // bool or nil (omitted) + wantTitle string + }{ + {"draft true", "my feature", true, "WIP: my feature"}, + {"draft false strips WIP:", "WIP: my feature", false, "my feature"}, + {"draft false strips [WIP]", "[WIP] my feature", false, "my feature"}, + {"draft omitted preserves title", "WIP: my feature", nil, "WIP: my feature"}, + {"draft true already prefixed", "WIP: my feature", true, "WIP: my feature"}, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + var ( + mu sync.Mutex + gotBody map[string]any + ) + + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/api/v1/version": + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"version":"1.12.0"}`)) + case fmt.Sprintf("/api/v1/repos/%s/%s", owner, repo): + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"private":false}`)) + case fmt.Sprintf("/api/v1/repos/%s/%s/pulls", owner, repo): + mu.Lock() + var body map[string]any + _ = json.NewDecoder(r.Body).Decode(&body) + gotBody = body + mu.Unlock() + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write(fmt.Appendf(nil, `{"number":1,"title":%q,"state":"open"}`, body["title"])) + default: + http.NotFound(w, r) + } + }) + + server := httptest.NewServer(handler) + defer server.Close() + + origHost := flag.Host + origToken := flag.Token + origVersion := flag.Version + flag.Host = server.URL + flag.Token = "" + flag.Version = "test" + defer func() { + flag.Host = origHost + flag.Token = origToken + flag.Version = origVersion + }() + + args := map[string]any{ + "owner": owner, + "repo": repo, + "title": tc.title, + "body": "test body", + "head": "feature", + "base": "main", + } + if tc.draft != nil { + args["draft"] = tc.draft + } + + req := mcp.CallToolRequest{ + Params: mcp.CallToolParams{ + Arguments: args, + }, + } + + _, err := createPullRequestFn(context.Background(), req) + if err != nil { + t.Fatalf("createPullRequestFn() error = %v", err) + } + + mu.Lock() + defer mu.Unlock() + + if gotBody["title"] != tc.wantTitle { + t.Fatalf("expected title %q, got %v", tc.wantTitle, gotBody["title"]) + } + }) + } +} + +func Test_editPullRequestFn_draft(t *testing.T) { + const ( + owner = "octo" + repo = "demo" + index = 7 + ) + + tests := []struct { + name string + title string // title arg passed to the tool; empty means omitted + draft any + wantTitle string + }{ + {"set draft with title", "my feature", true, "WIP: my feature"}, + {"unset draft with title", "WIP: my feature", false, "my feature"}, + {"set draft without title fetches current", "", true, "WIP: existing title"}, + {"unset draft without title fetches current", "", false, "existing title"}, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + var ( + mu sync.Mutex + gotBody map[string]any + ) + + prPath := fmt.Sprintf("/api/v1/repos/%s/%s/pulls/%d", owner, repo, index) + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/api/v1/version": + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"version":"1.12.0"}`)) + case fmt.Sprintf("/api/v1/repos/%s/%s", owner, repo): + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"private":false}`)) + case prPath: + w.Header().Set("Content-Type", "application/json") + if r.Method == http.MethodGet { + // Auto-fetch: return the existing PR with its current title + _, _ = w.Write(fmt.Appendf(nil, `{"number":%d,"title":"existing title","state":"open"}`, index)) + return + } + mu.Lock() + var body map[string]any + _ = json.NewDecoder(r.Body).Decode(&body) + gotBody = body + mu.Unlock() + title := "existing title" + if s, ok := body["title"].(string); ok { + title = s + } + _, _ = w.Write(fmt.Appendf(nil, `{"number":%d,"title":%q,"state":"open"}`, index, title)) + default: + http.NotFound(w, r) + } + }) + + server := httptest.NewServer(handler) + defer server.Close() + + origHost := flag.Host + origToken := flag.Token + origVersion := flag.Version + flag.Host = server.URL + flag.Token = "" + flag.Version = "test" + defer func() { + flag.Host = origHost + flag.Token = origToken + flag.Version = origVersion + }() + + args := map[string]any{ + "owner": owner, + "repo": repo, + "index": float64(index), + } + if tc.title != "" { + args["title"] = tc.title + } + if tc.draft != nil { + args["draft"] = tc.draft + } + + req := mcp.CallToolRequest{ + Params: mcp.CallToolParams{ + Arguments: args, + }, + } + + _, err := editPullRequestFn(context.Background(), req) + if err != nil { + t.Fatalf("editPullRequestFn() error = %v", err) + } + + mu.Lock() + defer mu.Unlock() + + if gotBody["title"] != tc.wantTitle { + t.Fatalf("expected title %q, got %v", tc.wantTitle, gotBody["title"]) + } + }) + } +} + func Test_getPullRequestDiffFn(t *testing.T) { const ( owner = "octo"