mirror of
https://gitea.com/gitea/gitea-mcp.git
synced 2026-03-23 13:25:13 +00:00
feat(pull): add draft parameter for creating/updating draft PRs (#159)
## Summary The Gitea API has no native `draft` field on `CreatePullRequestOption` or `EditPullRequestOption`. Instead, Gitea treats PRs whose title starts with a WIP prefix (e.g. `WIP:`, `[WIP]`) as drafts. This adds a `draft` boolean parameter to the `pull_request_write` tool so MCP clients can create/update draft PRs without knowing about the WIP prefix convention. ## Changes - Add `draft` boolean parameter to `PullRequestWriteTool` schema, supported on `create` and `update` methods - Add `applyDraftPrefix()` helper that handles both default Gitea WIP prefixes (`WIP:`, `[WIP]`) case-insensitively - When `draft=true` and no prefix exists, prepend `WIP: `; when a prefix already exists, preserve the title as-is (no normalization) - When `draft=false`, strip any recognized WIP prefix - On `update`, if `draft` is set without `title`, auto-fetch the current PR title via GET - Add tests: 12 unit tests for `applyDraftPrefix`, 5 integration tests for create, 4 for edit --------- Co-authored-by: tomholford <tomholford@users.noreply.github.com> Reviewed-on: https://gitea.com/gitea/gitea-mcp/pulls/159 Reviewed-by: silverwind <2021+silverwind@noreply.gitea.com> Co-authored-by: tomholford <128995+tomholford@noreply.gitea.com> Co-committed-by: tomholford <128995+tomholford@noreply.gitea.com>
This commit is contained in:
@@ -3,6 +3,7 @@ package pull
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
|
||||||
"gitea.com/gitea/gitea-mcp/pkg/gitea"
|
"gitea.com/gitea/gitea-mcp/pkg/gitea"
|
||||||
"gitea.com/gitea/gitea-mcp/pkg/log"
|
"gitea.com/gitea/gitea-mcp/pkg/log"
|
||||||
@@ -71,6 +72,7 @@ var (
|
|||||||
mcp.WithBoolean("delete_branch", mcp.Description("delete branch after merge (for 'merge')")),
|
mcp.WithBoolean("delete_branch", mcp.Description("delete branch after merge (for 'merge')")),
|
||||||
mcp.WithArray("reviewers", mcp.Description("reviewer usernames (for 'add_reviewers', 'remove_reviewers')"), mcp.Items(map[string]any{"type": "string"})),
|
mcp.WithArray("reviewers", mcp.Description("reviewer usernames (for 'add_reviewers', 'remove_reviewers')"), mcp.Items(map[string]any{"type": "string"})),
|
||||||
mcp.WithArray("team_reviewers", mcp.Description("team reviewer names (for 'add_reviewers', 'remove_reviewers')"), mcp.Items(map[string]any{"type": "string"})),
|
mcp.WithArray("team_reviewers", mcp.Description("team reviewer names (for 'add_reviewers', 'remove_reviewers')"), mcp.Items(map[string]any{"type": "string"})),
|
||||||
|
mcp.WithBoolean("draft", mcp.Description("mark PR as draft (for 'create', 'update'). Gitea uses a 'WIP: ' title prefix for drafts.")),
|
||||||
)
|
)
|
||||||
|
|
||||||
PullRequestReviewWriteTool = mcp.NewTool(
|
PullRequestReviewWriteTool = mcp.NewTool(
|
||||||
@@ -271,6 +273,28 @@ func listRepoPullRequestsFn(ctx context.Context, req mcp.CallToolRequest) (*mcp.
|
|||||||
return to.TextResult(slimPullRequests(pullRequests))
|
return to.TextResult(slimPullRequests(pullRequests))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// defaultWIPPrefixes are the default Gitea title prefixes that mark a PR as
|
||||||
|
// work-in-progress / draft. Gitea matches these case-insensitively.
|
||||||
|
var defaultWIPPrefixes = []string{"WIP:", "[WIP]"}
|
||||||
|
|
||||||
|
// applyDraftPrefix adds or removes a WIP title prefix that Gitea uses to mark
|
||||||
|
// pull requests as drafts. When the title already carries a recognized prefix
|
||||||
|
// and isDraft is true, the title is returned unchanged to avoid normalization.
|
||||||
|
func applyDraftPrefix(title string, isDraft bool) string {
|
||||||
|
for _, prefix := range defaultWIPPrefixes {
|
||||||
|
if len(title) >= len(prefix) && strings.EqualFold(title[:len(prefix)], prefix) {
|
||||||
|
if isDraft {
|
||||||
|
return title
|
||||||
|
}
|
||||||
|
return strings.TrimLeft(title[len(prefix):], " ")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if isDraft {
|
||||||
|
return "WIP: " + title
|
||||||
|
}
|
||||||
|
return title
|
||||||
|
}
|
||||||
|
|
||||||
func createPullRequestFn(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
|
func createPullRequestFn(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
|
||||||
log.Debugf("Called createPullRequestFn")
|
log.Debugf("Called createPullRequestFn")
|
||||||
args := req.GetArguments()
|
args := req.GetArguments()
|
||||||
@@ -298,6 +322,11 @@ func createPullRequestFn(ctx context.Context, req mcp.CallToolRequest) (*mcp.Cal
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return to.ErrorResult(err)
|
return to.ErrorResult(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if draft, ok := args["draft"].(bool); ok {
|
||||||
|
title = applyDraftPrefix(title, draft)
|
||||||
|
}
|
||||||
|
|
||||||
client, err := gitea.ClientFromContext(ctx)
|
client, err := gitea.ClientFromContext(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return to.ErrorResult(fmt.Errorf("get gitea client err: %v", err))
|
return to.ErrorResult(fmt.Errorf("get gitea client err: %v", err))
|
||||||
@@ -774,6 +803,22 @@ func editPullRequestFn(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallT
|
|||||||
if title, ok := args["title"].(string); ok {
|
if title, ok := args["title"].(string); ok {
|
||||||
opt.Title = title
|
opt.Title = title
|
||||||
}
|
}
|
||||||
|
if draft, ok := args["draft"].(bool); ok {
|
||||||
|
if opt.Title == "" {
|
||||||
|
// Fetch current title so the caller doesn't have to provide it
|
||||||
|
// just to toggle draft status.
|
||||||
|
client, err := gitea.ClientFromContext(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return to.ErrorResult(fmt.Errorf("get gitea client err: %v", err))
|
||||||
|
}
|
||||||
|
pr, _, err := client.GetPullRequest(owner, repo, index)
|
||||||
|
if err != nil {
|
||||||
|
return to.ErrorResult(fmt.Errorf("get %v/%v/pr/%v err: %v", owner, repo, index, err))
|
||||||
|
}
|
||||||
|
opt.Title = pr.Title
|
||||||
|
}
|
||||||
|
opt.Title = applyDraftPrefix(opt.Title, draft)
|
||||||
|
}
|
||||||
if body, ok := args["body"].(string); ok {
|
if body, ok := args["body"].(string); ok {
|
||||||
opt.Body = new(body)
|
opt.Body = new(body)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -254,6 +254,236 @@ func Test_mergePullRequestFn(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func Test_applyDraftPrefix(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
title string
|
||||||
|
isDraft bool
|
||||||
|
want string
|
||||||
|
}{
|
||||||
|
{"add prefix", "my feature", true, "WIP: my feature"},
|
||||||
|
{"already prefixed WIP:", "WIP: my feature", true, "WIP: my feature"},
|
||||||
|
{"already prefixed WIP: no space", "WIP:my feature", true, "WIP:my feature"},
|
||||||
|
{"already prefixed [WIP]", "[WIP] my feature", true, "[WIP] my feature"},
|
||||||
|
{"already prefixed case insensitive", "wip: my feature", true, "wip: my feature"},
|
||||||
|
{"already prefixed [wip]", "[wip] my feature", true, "[wip] my feature"},
|
||||||
|
{"remove WIP: prefix", "WIP: my feature", false, "my feature"},
|
||||||
|
{"remove WIP: no space", "WIP:my feature", false, "my feature"},
|
||||||
|
{"remove [WIP] prefix", "[WIP] my feature", false, "my feature"},
|
||||||
|
{"remove [wip] prefix", "[wip] my feature", false, "my feature"},
|
||||||
|
{"remove wip: lowercase", "wip: my feature", false, "my feature"},
|
||||||
|
{"no prefix not draft", "my feature", false, "my feature"},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got := applyDraftPrefix(tt.title, tt.isDraft)
|
||||||
|
if got != tt.want {
|
||||||
|
t.Fatalf("applyDraftPrefix(%q, %v) = %q, want %q", tt.title, tt.isDraft, got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_createPullRequestFn_draft(t *testing.T) {
|
||||||
|
const (
|
||||||
|
owner = "octo"
|
||||||
|
repo = "demo"
|
||||||
|
)
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
title string
|
||||||
|
draft any // bool or nil (omitted)
|
||||||
|
wantTitle string
|
||||||
|
}{
|
||||||
|
{"draft true", "my feature", true, "WIP: my feature"},
|
||||||
|
{"draft false strips WIP:", "WIP: my feature", false, "my feature"},
|
||||||
|
{"draft false strips [WIP]", "[WIP] my feature", false, "my feature"},
|
||||||
|
{"draft omitted preserves title", "WIP: my feature", nil, "WIP: my feature"},
|
||||||
|
{"draft true already prefixed", "WIP: my feature", true, "WIP: my feature"},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range tests {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
var (
|
||||||
|
mu sync.Mutex
|
||||||
|
gotBody map[string]any
|
||||||
|
)
|
||||||
|
|
||||||
|
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
switch r.URL.Path {
|
||||||
|
case "/api/v1/version":
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
_, _ = w.Write([]byte(`{"version":"1.12.0"}`))
|
||||||
|
case fmt.Sprintf("/api/v1/repos/%s/%s", owner, repo):
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
_, _ = w.Write([]byte(`{"private":false}`))
|
||||||
|
case fmt.Sprintf("/api/v1/repos/%s/%s/pulls", owner, repo):
|
||||||
|
mu.Lock()
|
||||||
|
var body map[string]any
|
||||||
|
_ = json.NewDecoder(r.Body).Decode(&body)
|
||||||
|
gotBody = body
|
||||||
|
mu.Unlock()
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
_, _ = w.Write(fmt.Appendf(nil, `{"number":1,"title":%q,"state":"open"}`, body["title"]))
|
||||||
|
default:
|
||||||
|
http.NotFound(w, r)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
server := httptest.NewServer(handler)
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
origHost := flag.Host
|
||||||
|
origToken := flag.Token
|
||||||
|
origVersion := flag.Version
|
||||||
|
flag.Host = server.URL
|
||||||
|
flag.Token = ""
|
||||||
|
flag.Version = "test"
|
||||||
|
defer func() {
|
||||||
|
flag.Host = origHost
|
||||||
|
flag.Token = origToken
|
||||||
|
flag.Version = origVersion
|
||||||
|
}()
|
||||||
|
|
||||||
|
args := map[string]any{
|
||||||
|
"owner": owner,
|
||||||
|
"repo": repo,
|
||||||
|
"title": tc.title,
|
||||||
|
"body": "test body",
|
||||||
|
"head": "feature",
|
||||||
|
"base": "main",
|
||||||
|
}
|
||||||
|
if tc.draft != nil {
|
||||||
|
args["draft"] = tc.draft
|
||||||
|
}
|
||||||
|
|
||||||
|
req := mcp.CallToolRequest{
|
||||||
|
Params: mcp.CallToolParams{
|
||||||
|
Arguments: args,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err := createPullRequestFn(context.Background(), req)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("createPullRequestFn() error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
mu.Lock()
|
||||||
|
defer mu.Unlock()
|
||||||
|
|
||||||
|
if gotBody["title"] != tc.wantTitle {
|
||||||
|
t.Fatalf("expected title %q, got %v", tc.wantTitle, gotBody["title"])
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_editPullRequestFn_draft(t *testing.T) {
|
||||||
|
const (
|
||||||
|
owner = "octo"
|
||||||
|
repo = "demo"
|
||||||
|
index = 7
|
||||||
|
)
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
title string // title arg passed to the tool; empty means omitted
|
||||||
|
draft any
|
||||||
|
wantTitle string
|
||||||
|
}{
|
||||||
|
{"set draft with title", "my feature", true, "WIP: my feature"},
|
||||||
|
{"unset draft with title", "WIP: my feature", false, "my feature"},
|
||||||
|
{"set draft without title fetches current", "", true, "WIP: existing title"},
|
||||||
|
{"unset draft without title fetches current", "", false, "existing title"},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range tests {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
var (
|
||||||
|
mu sync.Mutex
|
||||||
|
gotBody map[string]any
|
||||||
|
)
|
||||||
|
|
||||||
|
prPath := fmt.Sprintf("/api/v1/repos/%s/%s/pulls/%d", owner, repo, index)
|
||||||
|
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
switch r.URL.Path {
|
||||||
|
case "/api/v1/version":
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
_, _ = w.Write([]byte(`{"version":"1.12.0"}`))
|
||||||
|
case fmt.Sprintf("/api/v1/repos/%s/%s", owner, repo):
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
_, _ = w.Write([]byte(`{"private":false}`))
|
||||||
|
case prPath:
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
if r.Method == http.MethodGet {
|
||||||
|
// Auto-fetch: return the existing PR with its current title
|
||||||
|
_, _ = w.Write(fmt.Appendf(nil, `{"number":%d,"title":"existing title","state":"open"}`, index))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
mu.Lock()
|
||||||
|
var body map[string]any
|
||||||
|
_ = json.NewDecoder(r.Body).Decode(&body)
|
||||||
|
gotBody = body
|
||||||
|
mu.Unlock()
|
||||||
|
title := "existing title"
|
||||||
|
if s, ok := body["title"].(string); ok {
|
||||||
|
title = s
|
||||||
|
}
|
||||||
|
_, _ = w.Write(fmt.Appendf(nil, `{"number":%d,"title":%q,"state":"open"}`, index, title))
|
||||||
|
default:
|
||||||
|
http.NotFound(w, r)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
server := httptest.NewServer(handler)
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
origHost := flag.Host
|
||||||
|
origToken := flag.Token
|
||||||
|
origVersion := flag.Version
|
||||||
|
flag.Host = server.URL
|
||||||
|
flag.Token = ""
|
||||||
|
flag.Version = "test"
|
||||||
|
defer func() {
|
||||||
|
flag.Host = origHost
|
||||||
|
flag.Token = origToken
|
||||||
|
flag.Version = origVersion
|
||||||
|
}()
|
||||||
|
|
||||||
|
args := map[string]any{
|
||||||
|
"owner": owner,
|
||||||
|
"repo": repo,
|
||||||
|
"index": float64(index),
|
||||||
|
}
|
||||||
|
if tc.title != "" {
|
||||||
|
args["title"] = tc.title
|
||||||
|
}
|
||||||
|
if tc.draft != nil {
|
||||||
|
args["draft"] = tc.draft
|
||||||
|
}
|
||||||
|
|
||||||
|
req := mcp.CallToolRequest{
|
||||||
|
Params: mcp.CallToolParams{
|
||||||
|
Arguments: args,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err := editPullRequestFn(context.Background(), req)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("editPullRequestFn() error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
mu.Lock()
|
||||||
|
defer mu.Unlock()
|
||||||
|
|
||||||
|
if gotBody["title"] != tc.wantTitle {
|
||||||
|
t.Fatalf("expected title %q, got %v", tc.wantTitle, gotBody["title"])
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func Test_getPullRequestDiffFn(t *testing.T) {
|
func Test_getPullRequestDiffFn(t *testing.T) {
|
||||||
const (
|
const (
|
||||||
owner = "octo"
|
owner = "octo"
|
||||||
|
|||||||
Reference in New Issue
Block a user