mirror of
https://gitea.com/gitea/gitea-mcp.git
synced 2026-03-23 21:35:11 +00:00
## Summary The Gitea API has no native `draft` field on `CreatePullRequestOption` or `EditPullRequestOption`. Instead, Gitea treats PRs whose title starts with a WIP prefix (e.g. `WIP:`, `[WIP]`) as drafts. This adds a `draft` boolean parameter to the `pull_request_write` tool so MCP clients can create/update draft PRs without knowing about the WIP prefix convention. ## Changes - Add `draft` boolean parameter to `PullRequestWriteTool` schema, supported on `create` and `update` methods - Add `applyDraftPrefix()` helper that handles both default Gitea WIP prefixes (`WIP:`, `[WIP]`) case-insensitively - When `draft=true` and no prefix exists, prepend `WIP: `; when a prefix already exists, preserve the title as-is (no normalization) - When `draft=false`, strip any recognized WIP prefix - On `update`, if `draft` is set without `title`, auto-fetch the current PR title via GET - Add tests: 12 unit tests for `applyDraftPrefix`, 5 integration tests for create, 4 for edit --------- Co-authored-by: tomholford <tomholford@users.noreply.github.com> Reviewed-on: https://gitea.com/gitea/gitea-mcp/pulls/159 Reviewed-by: silverwind <2021+silverwind@noreply.gitea.com> Co-authored-by: tomholford <128995+tomholford@noreply.gitea.com> Co-committed-by: tomholford <128995+tomholford@noreply.gitea.com>
610 lines
16 KiB
Go
610 lines
16 KiB
Go
package pull
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"sync"
|
|
"testing"
|
|
|
|
"gitea.com/gitea/gitea-mcp/pkg/flag"
|
|
"github.com/mark3labs/mcp-go/mcp"
|
|
)
|
|
|
|
func Test_editPullRequestFn(t *testing.T) {
|
|
const (
|
|
owner = "octo"
|
|
repo = "demo"
|
|
index = 7
|
|
)
|
|
|
|
indexInputs := []struct {
|
|
name string
|
|
val any
|
|
}{
|
|
{"float64", float64(index)},
|
|
{"string", "7"},
|
|
}
|
|
|
|
for _, ii := range indexInputs {
|
|
t.Run(ii.name, func(t *testing.T) {
|
|
var (
|
|
mu sync.Mutex
|
|
gotMethod string
|
|
gotPath string
|
|
gotBody map[string]any
|
|
)
|
|
|
|
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
switch r.URL.Path {
|
|
case "/api/v1/version":
|
|
w.Header().Set("Content-Type", "application/json")
|
|
_, _ = w.Write([]byte(`{"version":"1.12.0"}`))
|
|
case fmt.Sprintf("/api/v1/repos/%s/%s", owner, repo):
|
|
w.Header().Set("Content-Type", "application/json")
|
|
_, _ = w.Write([]byte(`{"private":false}`))
|
|
case fmt.Sprintf("/api/v1/repos/%s/%s/pulls/%d", owner, repo, index):
|
|
mu.Lock()
|
|
gotMethod = r.Method
|
|
gotPath = r.URL.Path
|
|
var body map[string]any
|
|
_ = json.NewDecoder(r.Body).Decode(&body)
|
|
gotBody = body
|
|
mu.Unlock()
|
|
w.Header().Set("Content-Type", "application/json")
|
|
_, _ = w.Write(fmt.Appendf(nil, `{"number":%d,"title":"%s","state":"open"}`, index, body["title"]))
|
|
default:
|
|
http.NotFound(w, r)
|
|
}
|
|
})
|
|
|
|
server := httptest.NewServer(handler)
|
|
defer server.Close()
|
|
|
|
origHost := flag.Host
|
|
origToken := flag.Token
|
|
origVersion := flag.Version
|
|
flag.Host = server.URL
|
|
flag.Token = ""
|
|
flag.Version = "test"
|
|
defer func() {
|
|
flag.Host = origHost
|
|
flag.Token = origToken
|
|
flag.Version = origVersion
|
|
}()
|
|
|
|
req := mcp.CallToolRequest{
|
|
Params: mcp.CallToolParams{
|
|
Arguments: map[string]any{
|
|
"owner": owner,
|
|
"repo": repo,
|
|
"index": ii.val,
|
|
"title": "WIP: my feature",
|
|
"state": "open",
|
|
},
|
|
},
|
|
}
|
|
|
|
result, err := editPullRequestFn(context.Background(), req)
|
|
if err != nil {
|
|
t.Fatalf("editPullRequestFn() error = %v", err)
|
|
}
|
|
|
|
mu.Lock()
|
|
defer mu.Unlock()
|
|
|
|
if gotMethod != http.MethodPatch {
|
|
t.Fatalf("expected PATCH request, got %s", gotMethod)
|
|
}
|
|
if gotPath != fmt.Sprintf("/api/v1/repos/%s/%s/pulls/%d", owner, repo, index) {
|
|
t.Fatalf("unexpected path: %s", gotPath)
|
|
}
|
|
if gotBody["title"] != "WIP: my feature" {
|
|
t.Fatalf("expected title 'WIP: my feature', got %v", gotBody["title"])
|
|
}
|
|
if gotBody["state"] != "open" {
|
|
t.Fatalf("expected state 'open', got %v", gotBody["state"])
|
|
}
|
|
|
|
if len(result.Content) == 0 {
|
|
t.Fatalf("expected content in result")
|
|
}
|
|
textContent, ok := mcp.AsTextContent(result.Content[0])
|
|
if !ok {
|
|
t.Fatalf("expected text content, got %T", result.Content[0])
|
|
}
|
|
|
|
var parsed map[string]any
|
|
if err := json.Unmarshal([]byte(textContent.Text), &parsed); err != nil {
|
|
t.Fatalf("unmarshal result text: %v", err)
|
|
}
|
|
if got := parsed["title"].(string); got != "WIP: my feature" {
|
|
t.Fatalf("result title = %q, want %q", got, "WIP: my feature")
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func Test_mergePullRequestFn(t *testing.T) {
|
|
const (
|
|
owner = "octo"
|
|
repo = "demo"
|
|
index = 5
|
|
)
|
|
|
|
indexInputs := []struct {
|
|
name string
|
|
val any
|
|
}{
|
|
{"float64", float64(index)},
|
|
{"string", "5"},
|
|
}
|
|
|
|
for _, ii := range indexInputs {
|
|
t.Run(ii.name, func(t *testing.T) {
|
|
var (
|
|
mu sync.Mutex
|
|
gotMethod string
|
|
gotPath string
|
|
gotBody map[string]any
|
|
)
|
|
|
|
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
switch r.URL.Path {
|
|
case "/api/v1/version":
|
|
w.Header().Set("Content-Type", "application/json")
|
|
_, _ = w.Write([]byte(`{"version":"1.12.0"}`))
|
|
case fmt.Sprintf("/api/v1/repos/%s/%s", owner, repo):
|
|
w.Header().Set("Content-Type", "application/json")
|
|
_, _ = w.Write([]byte(`{"private":false}`))
|
|
case fmt.Sprintf("/api/v1/repos/%s/%s/pulls/%d/merge", owner, repo, index):
|
|
mu.Lock()
|
|
gotMethod = r.Method
|
|
gotPath = r.URL.Path
|
|
var body map[string]any
|
|
_ = json.NewDecoder(r.Body).Decode(&body)
|
|
gotBody = body
|
|
mu.Unlock()
|
|
w.WriteHeader(http.StatusOK)
|
|
default:
|
|
http.NotFound(w, r)
|
|
}
|
|
})
|
|
|
|
server := httptest.NewServer(handler)
|
|
defer server.Close()
|
|
|
|
origHost := flag.Host
|
|
origToken := flag.Token
|
|
origVersion := flag.Version
|
|
flag.Host = server.URL
|
|
flag.Token = ""
|
|
flag.Version = "test"
|
|
defer func() {
|
|
flag.Host = origHost
|
|
flag.Token = origToken
|
|
flag.Version = origVersion
|
|
}()
|
|
|
|
req := mcp.CallToolRequest{
|
|
Params: mcp.CallToolParams{
|
|
Arguments: map[string]any{
|
|
"owner": owner,
|
|
"repo": repo,
|
|
"index": ii.val,
|
|
"merge_style": "squash",
|
|
"title": "feat: my squashed commit",
|
|
"message": "Squash merge of PR #5",
|
|
"delete_branch": true,
|
|
},
|
|
},
|
|
}
|
|
|
|
result, err := mergePullRequestFn(context.Background(), req)
|
|
if err != nil {
|
|
t.Fatalf("mergePullRequestFn() error = %v", err)
|
|
}
|
|
|
|
mu.Lock()
|
|
defer mu.Unlock()
|
|
|
|
if gotMethod != http.MethodPost {
|
|
t.Fatalf("expected POST request, got %s", gotMethod)
|
|
}
|
|
if gotPath != fmt.Sprintf("/api/v1/repos/%s/%s/pulls/%d/merge", owner, repo, index) {
|
|
t.Fatalf("unexpected path: %s", gotPath)
|
|
}
|
|
if gotBody["Do"] != "squash" {
|
|
t.Fatalf("expected Do 'squash', got %v", gotBody["Do"])
|
|
}
|
|
if gotBody["MergeTitleField"] != "feat: my squashed commit" {
|
|
t.Fatalf("expected MergeTitleField 'feat: my squashed commit', got %v", gotBody["MergeTitleField"])
|
|
}
|
|
if gotBody["MergeMessageField"] != "Squash merge of PR #5" {
|
|
t.Fatalf("expected MergeMessageField 'Squash merge of PR #5', got %v", gotBody["MergeMessageField"])
|
|
}
|
|
if gotBody["delete_branch_after_merge"] != true {
|
|
t.Fatalf("expected delete_branch_after_merge true, got %v", gotBody["delete_branch_after_merge"])
|
|
}
|
|
|
|
if len(result.Content) == 0 {
|
|
t.Fatalf("expected content in result")
|
|
}
|
|
textContent, ok := mcp.AsTextContent(result.Content[0])
|
|
if !ok {
|
|
t.Fatalf("expected text content, got %T", result.Content[0])
|
|
}
|
|
|
|
var parsed map[string]any
|
|
if err := json.Unmarshal([]byte(textContent.Text), &parsed); err != nil {
|
|
t.Fatalf("unmarshal result text: %v", err)
|
|
}
|
|
if parsed["merged"] != true {
|
|
t.Fatalf("expected merged=true, got %v", parsed["merged"])
|
|
}
|
|
if parsed["merge_style"] != "squash" {
|
|
t.Fatalf("expected merge_style 'squash', got %v", parsed["merge_style"])
|
|
}
|
|
if parsed["branch_deleted"] != true {
|
|
t.Fatalf("expected branch_deleted=true, got %v", parsed["branch_deleted"])
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func Test_applyDraftPrefix(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
title string
|
|
isDraft bool
|
|
want string
|
|
}{
|
|
{"add prefix", "my feature", true, "WIP: my feature"},
|
|
{"already prefixed WIP:", "WIP: my feature", true, "WIP: my feature"},
|
|
{"already prefixed WIP: no space", "WIP:my feature", true, "WIP:my feature"},
|
|
{"already prefixed [WIP]", "[WIP] my feature", true, "[WIP] my feature"},
|
|
{"already prefixed case insensitive", "wip: my feature", true, "wip: my feature"},
|
|
{"already prefixed [wip]", "[wip] my feature", true, "[wip] my feature"},
|
|
{"remove WIP: prefix", "WIP: my feature", false, "my feature"},
|
|
{"remove WIP: no space", "WIP:my feature", false, "my feature"},
|
|
{"remove [WIP] prefix", "[WIP] my feature", false, "my feature"},
|
|
{"remove [wip] prefix", "[wip] my feature", false, "my feature"},
|
|
{"remove wip: lowercase", "wip: my feature", false, "my feature"},
|
|
{"no prefix not draft", "my feature", false, "my feature"},
|
|
}
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
got := applyDraftPrefix(tt.title, tt.isDraft)
|
|
if got != tt.want {
|
|
t.Fatalf("applyDraftPrefix(%q, %v) = %q, want %q", tt.title, tt.isDraft, got, tt.want)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func Test_createPullRequestFn_draft(t *testing.T) {
|
|
const (
|
|
owner = "octo"
|
|
repo = "demo"
|
|
)
|
|
|
|
tests := []struct {
|
|
name string
|
|
title string
|
|
draft any // bool or nil (omitted)
|
|
wantTitle string
|
|
}{
|
|
{"draft true", "my feature", true, "WIP: my feature"},
|
|
{"draft false strips WIP:", "WIP: my feature", false, "my feature"},
|
|
{"draft false strips [WIP]", "[WIP] my feature", false, "my feature"},
|
|
{"draft omitted preserves title", "WIP: my feature", nil, "WIP: my feature"},
|
|
{"draft true already prefixed", "WIP: my feature", true, "WIP: my feature"},
|
|
}
|
|
|
|
for _, tc := range tests {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
var (
|
|
mu sync.Mutex
|
|
gotBody map[string]any
|
|
)
|
|
|
|
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
switch r.URL.Path {
|
|
case "/api/v1/version":
|
|
w.Header().Set("Content-Type", "application/json")
|
|
_, _ = w.Write([]byte(`{"version":"1.12.0"}`))
|
|
case fmt.Sprintf("/api/v1/repos/%s/%s", owner, repo):
|
|
w.Header().Set("Content-Type", "application/json")
|
|
_, _ = w.Write([]byte(`{"private":false}`))
|
|
case fmt.Sprintf("/api/v1/repos/%s/%s/pulls", owner, repo):
|
|
mu.Lock()
|
|
var body map[string]any
|
|
_ = json.NewDecoder(r.Body).Decode(&body)
|
|
gotBody = body
|
|
mu.Unlock()
|
|
w.Header().Set("Content-Type", "application/json")
|
|
_, _ = w.Write(fmt.Appendf(nil, `{"number":1,"title":%q,"state":"open"}`, body["title"]))
|
|
default:
|
|
http.NotFound(w, r)
|
|
}
|
|
})
|
|
|
|
server := httptest.NewServer(handler)
|
|
defer server.Close()
|
|
|
|
origHost := flag.Host
|
|
origToken := flag.Token
|
|
origVersion := flag.Version
|
|
flag.Host = server.URL
|
|
flag.Token = ""
|
|
flag.Version = "test"
|
|
defer func() {
|
|
flag.Host = origHost
|
|
flag.Token = origToken
|
|
flag.Version = origVersion
|
|
}()
|
|
|
|
args := map[string]any{
|
|
"owner": owner,
|
|
"repo": repo,
|
|
"title": tc.title,
|
|
"body": "test body",
|
|
"head": "feature",
|
|
"base": "main",
|
|
}
|
|
if tc.draft != nil {
|
|
args["draft"] = tc.draft
|
|
}
|
|
|
|
req := mcp.CallToolRequest{
|
|
Params: mcp.CallToolParams{
|
|
Arguments: args,
|
|
},
|
|
}
|
|
|
|
_, err := createPullRequestFn(context.Background(), req)
|
|
if err != nil {
|
|
t.Fatalf("createPullRequestFn() error = %v", err)
|
|
}
|
|
|
|
mu.Lock()
|
|
defer mu.Unlock()
|
|
|
|
if gotBody["title"] != tc.wantTitle {
|
|
t.Fatalf("expected title %q, got %v", tc.wantTitle, gotBody["title"])
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func Test_editPullRequestFn_draft(t *testing.T) {
|
|
const (
|
|
owner = "octo"
|
|
repo = "demo"
|
|
index = 7
|
|
)
|
|
|
|
tests := []struct {
|
|
name string
|
|
title string // title arg passed to the tool; empty means omitted
|
|
draft any
|
|
wantTitle string
|
|
}{
|
|
{"set draft with title", "my feature", true, "WIP: my feature"},
|
|
{"unset draft with title", "WIP: my feature", false, "my feature"},
|
|
{"set draft without title fetches current", "", true, "WIP: existing title"},
|
|
{"unset draft without title fetches current", "", false, "existing title"},
|
|
}
|
|
|
|
for _, tc := range tests {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
var (
|
|
mu sync.Mutex
|
|
gotBody map[string]any
|
|
)
|
|
|
|
prPath := fmt.Sprintf("/api/v1/repos/%s/%s/pulls/%d", owner, repo, index)
|
|
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
switch r.URL.Path {
|
|
case "/api/v1/version":
|
|
w.Header().Set("Content-Type", "application/json")
|
|
_, _ = w.Write([]byte(`{"version":"1.12.0"}`))
|
|
case fmt.Sprintf("/api/v1/repos/%s/%s", owner, repo):
|
|
w.Header().Set("Content-Type", "application/json")
|
|
_, _ = w.Write([]byte(`{"private":false}`))
|
|
case prPath:
|
|
w.Header().Set("Content-Type", "application/json")
|
|
if r.Method == http.MethodGet {
|
|
// Auto-fetch: return the existing PR with its current title
|
|
_, _ = w.Write(fmt.Appendf(nil, `{"number":%d,"title":"existing title","state":"open"}`, index))
|
|
return
|
|
}
|
|
mu.Lock()
|
|
var body map[string]any
|
|
_ = json.NewDecoder(r.Body).Decode(&body)
|
|
gotBody = body
|
|
mu.Unlock()
|
|
title := "existing title"
|
|
if s, ok := body["title"].(string); ok {
|
|
title = s
|
|
}
|
|
_, _ = w.Write(fmt.Appendf(nil, `{"number":%d,"title":%q,"state":"open"}`, index, title))
|
|
default:
|
|
http.NotFound(w, r)
|
|
}
|
|
})
|
|
|
|
server := httptest.NewServer(handler)
|
|
defer server.Close()
|
|
|
|
origHost := flag.Host
|
|
origToken := flag.Token
|
|
origVersion := flag.Version
|
|
flag.Host = server.URL
|
|
flag.Token = ""
|
|
flag.Version = "test"
|
|
defer func() {
|
|
flag.Host = origHost
|
|
flag.Token = origToken
|
|
flag.Version = origVersion
|
|
}()
|
|
|
|
args := map[string]any{
|
|
"owner": owner,
|
|
"repo": repo,
|
|
"index": float64(index),
|
|
}
|
|
if tc.title != "" {
|
|
args["title"] = tc.title
|
|
}
|
|
if tc.draft != nil {
|
|
args["draft"] = tc.draft
|
|
}
|
|
|
|
req := mcp.CallToolRequest{
|
|
Params: mcp.CallToolParams{
|
|
Arguments: args,
|
|
},
|
|
}
|
|
|
|
_, err := editPullRequestFn(context.Background(), req)
|
|
if err != nil {
|
|
t.Fatalf("editPullRequestFn() error = %v", err)
|
|
}
|
|
|
|
mu.Lock()
|
|
defer mu.Unlock()
|
|
|
|
if gotBody["title"] != tc.wantTitle {
|
|
t.Fatalf("expected title %q, got %v", tc.wantTitle, gotBody["title"])
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func Test_getPullRequestDiffFn(t *testing.T) {
|
|
const (
|
|
owner = "octo"
|
|
repo = "demo"
|
|
index = 12
|
|
diffRaw = "diff --git a/file.txt b/file.txt\n+line\n"
|
|
)
|
|
|
|
indexInputs := []struct {
|
|
name string
|
|
val any
|
|
}{
|
|
{"float64", float64(index)},
|
|
{"string", "12"},
|
|
}
|
|
|
|
for _, ii := range indexInputs {
|
|
t.Run(ii.name, func(t *testing.T) {
|
|
var (
|
|
mu sync.Mutex
|
|
diffRequested bool
|
|
binaryValue string
|
|
)
|
|
errCh := make(chan error, 1)
|
|
|
|
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
switch r.URL.Path {
|
|
case "/api/v1/version":
|
|
w.Header().Set("Content-Type", "application/json")
|
|
_, _ = w.Write([]byte(`{"version":"1.12.0"}`))
|
|
case fmt.Sprintf("/api/v1/repos/%s/%s", owner, repo):
|
|
w.Header().Set("Content-Type", "application/json")
|
|
_, _ = w.Write([]byte(`{"private":false}`))
|
|
case fmt.Sprintf("/%s/%s/pulls/%d.diff", owner, repo, index):
|
|
if r.Method != http.MethodGet {
|
|
select {
|
|
case errCh <- fmt.Errorf("unexpected method: %s", r.Method):
|
|
default:
|
|
}
|
|
}
|
|
mu.Lock()
|
|
diffRequested = true
|
|
binaryValue = r.URL.Query().Get("binary")
|
|
mu.Unlock()
|
|
w.Header().Set("Content-Type", "text/plain")
|
|
_, _ = w.Write([]byte(diffRaw))
|
|
default:
|
|
select {
|
|
case errCh <- fmt.Errorf("unexpected request path: %s", r.URL.Path):
|
|
default:
|
|
}
|
|
}
|
|
})
|
|
|
|
server := httptest.NewServer(handler)
|
|
defer server.Close()
|
|
|
|
origHost := flag.Host
|
|
origToken := flag.Token
|
|
origVersion := flag.Version
|
|
flag.Host = server.URL
|
|
flag.Token = ""
|
|
flag.Version = "test"
|
|
defer func() {
|
|
flag.Host = origHost
|
|
flag.Token = origToken
|
|
flag.Version = origVersion
|
|
}()
|
|
|
|
req := mcp.CallToolRequest{
|
|
Params: mcp.CallToolParams{
|
|
Arguments: map[string]any{
|
|
"owner": owner,
|
|
"repo": repo,
|
|
"index": ii.val,
|
|
"binary": true,
|
|
},
|
|
},
|
|
}
|
|
|
|
result, err := getPullRequestDiffFn(context.Background(), req)
|
|
if err != nil {
|
|
t.Fatalf("getPullRequestDiffFn() error = %v", err)
|
|
}
|
|
|
|
select {
|
|
case reqErr := <-errCh:
|
|
t.Fatalf("handler error: %v", reqErr)
|
|
default:
|
|
}
|
|
|
|
mu.Lock()
|
|
requested := diffRequested
|
|
gotBinary := binaryValue
|
|
mu.Unlock()
|
|
|
|
if !requested {
|
|
t.Fatalf("expected diff request to be made")
|
|
}
|
|
if gotBinary != "true" {
|
|
t.Fatalf("expected binary=true query param, got %q", gotBinary)
|
|
}
|
|
|
|
if len(result.Content) == 0 {
|
|
t.Fatalf("expected content in result")
|
|
}
|
|
|
|
textContent, ok := mcp.AsTextContent(result.Content[0])
|
|
if !ok {
|
|
t.Fatalf("expected text content, got %T", result.Content[0])
|
|
}
|
|
|
|
// The diff response is now a plain string
|
|
var parsed string
|
|
if err := json.Unmarshal([]byte(textContent.Text), &parsed); err != nil {
|
|
t.Fatalf("unmarshal result text: %v", err)
|
|
}
|
|
if parsed != diffRaw {
|
|
t.Fatalf("diff = %q, want %q", parsed, diffRaw)
|
|
}
|
|
})
|
|
}
|
|
}
|