mirror of
https://gitea.com/gitea/gitea-mcp.git
synced 2026-03-18 10:55:12 +00:00
Compare commits
6 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
0bdf8f5bb3 | ||
|
|
7bf54b9e83 | ||
|
|
c57e4c2e57 | ||
|
|
22fc663387 | ||
|
|
e0abd256a3 | ||
|
|
73263e74d0 |
7
main.go
7
main.go
@@ -1,6 +1,8 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"runtime/debug"
|
||||
|
||||
"gitea.com/gitea/gitea-mcp/cmd"
|
||||
"gitea.com/gitea/gitea-mcp/pkg/flag"
|
||||
)
|
||||
@@ -8,6 +10,11 @@ import (
|
||||
var Version = "dev"
|
||||
|
||||
func init() {
|
||||
if Version == "dev" {
|
||||
if info, ok := debug.ReadBuildInfo(); ok && info.Main.Version != "" && info.Main.Version != "(devel)" {
|
||||
Version = info.Main.Version
|
||||
}
|
||||
}
|
||||
flag.Version = Version
|
||||
}
|
||||
|
||||
|
||||
@@ -2,6 +2,7 @@ package operation
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"os"
|
||||
@@ -136,7 +137,7 @@ func Run() error {
|
||||
close(shutdownDone)
|
||||
}()
|
||||
|
||||
if err := httpServer.Start(fmt.Sprintf(":%d", flag.Port)); err != nil {
|
||||
if err := httpServer.Start(fmt.Sprintf(":%d", flag.Port)); err != nil && !errors.Is(err, http.ErrServerClosed) {
|
||||
return err
|
||||
}
|
||||
<-shutdownDone // Wait for shutdown to finish
|
||||
|
||||
@@ -2,6 +2,7 @@ package repo
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"gitea.com/gitea/gitea-mcp/pkg/gitea"
|
||||
@@ -18,9 +19,10 @@ import (
|
||||
var Tool = tool.New()
|
||||
|
||||
const (
|
||||
CreateRepoToolName = "create_repo"
|
||||
ForkRepoToolName = "fork_repo"
|
||||
ListMyReposToolName = "list_my_repos"
|
||||
CreateRepoToolName = "create_repo"
|
||||
ForkRepoToolName = "fork_repo"
|
||||
ListMyReposToolName = "list_my_repos"
|
||||
ListOrgReposToolName = "list_org_repos"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -55,6 +57,14 @@ var (
|
||||
mcp.WithNumber("page", mcp.Required(), mcp.Description("Page number"), mcp.DefaultNumber(1), mcp.Min(1)),
|
||||
mcp.WithNumber("perPage", mcp.Required(), mcp.Description("results per page"), mcp.DefaultNumber(30), mcp.Min(1)),
|
||||
)
|
||||
|
||||
ListOrgReposTool = mcp.NewTool(
|
||||
ListOrgReposToolName,
|
||||
mcp.WithDescription("List repositories of an organization"),
|
||||
mcp.WithString("org", mcp.Required(), mcp.Description("Organization name")),
|
||||
mcp.WithNumber("page", mcp.Required(), mcp.Description("Page number"), mcp.DefaultNumber(1), mcp.Min(1)),
|
||||
mcp.WithNumber("pageSize", mcp.Required(), mcp.Description("Page size number"), mcp.DefaultNumber(100), mcp.Min(1)),
|
||||
)
|
||||
)
|
||||
|
||||
func init() {
|
||||
@@ -70,6 +80,10 @@ func init() {
|
||||
Tool: ListMyReposTool,
|
||||
Handler: ListMyReposFn,
|
||||
})
|
||||
Tool.RegisterRead(server.ServerTool{
|
||||
Tool: ListOrgReposTool,
|
||||
Handler: ListOrgReposFn,
|
||||
})
|
||||
}
|
||||
|
||||
func CreateRepoFn(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
|
||||
@@ -178,3 +192,34 @@ func ListMyReposFn(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolR
|
||||
|
||||
return to.TextResult(slimRepos(repos))
|
||||
}
|
||||
|
||||
func ListOrgReposFn(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
|
||||
log.Debugf("Called ListOrgReposFn")
|
||||
org, ok := req.GetArguments()["org"].(string)
|
||||
if !ok {
|
||||
return to.ErrorResult(errors.New("organization name is required"))
|
||||
}
|
||||
page, ok := req.GetArguments()["page"].(float64)
|
||||
if !ok {
|
||||
page = 1
|
||||
}
|
||||
pageSize, ok := req.GetArguments()["pageSize"].(float64)
|
||||
if !ok {
|
||||
pageSize = 100
|
||||
}
|
||||
opt := gitea_sdk.ListOrgReposOptions{
|
||||
ListOptions: gitea_sdk.ListOptions{
|
||||
Page: int(page),
|
||||
PageSize: int(pageSize),
|
||||
},
|
||||
}
|
||||
client, err := gitea.ClientFromContext(ctx)
|
||||
if err != nil {
|
||||
return to.ErrorResult(fmt.Errorf("get gitea client err: %v", err))
|
||||
}
|
||||
repos, _, err := client.ListOrgRepos(org, opt)
|
||||
if err != nil {
|
||||
return to.ErrorResult(fmt.Errorf("list organization '%s' repositories error: %v", org, err))
|
||||
}
|
||||
return to.TextResult(repos)
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@ package wiki
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"net/url"
|
||||
|
||||
@@ -40,7 +41,7 @@ var (
|
||||
mcp.WithString("repo", mcp.Required(), mcp.Description("repository name")),
|
||||
mcp.WithString("pageName", mcp.Description("wiki page name (required for 'update', 'delete')")),
|
||||
mcp.WithString("title", mcp.Description("wiki page title (required for 'create', optional for 'update')")),
|
||||
mcp.WithString("content_base64", mcp.Description("page content, base64 encoded (required for 'create', 'update')")),
|
||||
mcp.WithString("content", mcp.Description("page content (required for 'create', 'update')")),
|
||||
mcp.WithString("message", mcp.Description("commit message")),
|
||||
)
|
||||
)
|
||||
@@ -176,7 +177,7 @@ func createWikiPageFn(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallTo
|
||||
if err != nil {
|
||||
return to.ErrorResult(err)
|
||||
}
|
||||
contentBase64, err := params.GetString(args, "content_base64")
|
||||
content, err := params.GetString(args, "content")
|
||||
if err != nil {
|
||||
return to.ErrorResult(err)
|
||||
}
|
||||
@@ -188,7 +189,7 @@ func createWikiPageFn(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallTo
|
||||
|
||||
requestBody := map[string]string{
|
||||
"title": title,
|
||||
"content_base64": contentBase64,
|
||||
"content_base64": base64.StdEncoding.EncodeToString([]byte(content)),
|
||||
"message": message,
|
||||
}
|
||||
|
||||
@@ -216,13 +217,13 @@ func updateWikiPageFn(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallTo
|
||||
if err != nil {
|
||||
return to.ErrorResult(err)
|
||||
}
|
||||
contentBase64, err := params.GetString(args, "content_base64")
|
||||
content, err := params.GetString(args, "content")
|
||||
if err != nil {
|
||||
return to.ErrorResult(err)
|
||||
}
|
||||
|
||||
requestBody := map[string]string{
|
||||
"content_base64": contentBase64,
|
||||
"content_base64": base64.StdEncoding.EncodeToString([]byte(content)),
|
||||
}
|
||||
|
||||
// If title is given, use it. Otherwise, keep current page name
|
||||
|
||||
75
operation/wiki/wiki_test.go
Normal file
75
operation/wiki/wiki_test.go
Normal file
@@ -0,0 +1,75 @@
|
||||
package wiki
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
mcpContext "gitea.com/gitea/gitea-mcp/pkg/context"
|
||||
"gitea.com/gitea/gitea-mcp/pkg/flag"
|
||||
|
||||
"github.com/mark3labs/mcp-go/mcp"
|
||||
)
|
||||
|
||||
func TestWikiWriteBase64Encoding(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
method string
|
||||
content string
|
||||
}{
|
||||
{"create ascii", "create", "Hello, World!"},
|
||||
{"create unicode", "create", "日本語テスト 🎉"},
|
||||
{"create multiline", "create", "line1\nline2\nline3"},
|
||||
{"update ascii", "update", "Updated content"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
var gotBody map[string]string
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
body, _ := io.ReadAll(r.Body)
|
||||
json.Unmarshal(body, &gotBody)
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte(`{"title":"test"}`))
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
origHost := flag.Host
|
||||
flag.Host = srv.URL
|
||||
defer func() { flag.Host = origHost }()
|
||||
|
||||
ctx := context.WithValue(context.Background(), mcpContext.TokenContextKey, "test-token")
|
||||
|
||||
args := map[string]any{
|
||||
"method": tt.method,
|
||||
"owner": "org",
|
||||
"repo": "repo",
|
||||
"content": tt.content,
|
||||
"pageName": "TestPage",
|
||||
"title": "TestPage",
|
||||
}
|
||||
|
||||
req := mcp.CallToolRequest{}
|
||||
req.Params.Arguments = args
|
||||
|
||||
result, err := wikiWriteFn(ctx, req)
|
||||
if err != nil {
|
||||
t.Fatalf("wikiWriteFn() error: %v", err)
|
||||
}
|
||||
if result.IsError {
|
||||
t.Fatalf("wikiWriteFn() returned error result")
|
||||
}
|
||||
|
||||
got := gotBody["content_base64"]
|
||||
want := base64.StdEncoding.EncodeToString([]byte(tt.content))
|
||||
if got != want {
|
||||
t.Errorf("content_base64 = %q, want %q", got, want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -3,6 +3,7 @@ package gitea
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
@@ -13,7 +14,8 @@ import (
|
||||
|
||||
func NewClient(token string) (*gitea.Client, error) {
|
||||
httpClient := &http.Client{
|
||||
Transport: http.DefaultTransport,
|
||||
Transport: http.DefaultTransport,
|
||||
CheckRedirect: checkRedirect,
|
||||
}
|
||||
|
||||
opts := []gitea.ClientOption{
|
||||
@@ -38,6 +40,19 @@ func NewClient(token string) (*gitea.Client, error) {
|
||||
return client, nil
|
||||
}
|
||||
|
||||
// checkRedirect prevents Go from silently changing mutating requests (POST, PATCH, etc.)
|
||||
// to GET when following 301/302/303 redirects, which would drop the request body and
|
||||
// make writes appear to succeed when they didn't.
|
||||
func checkRedirect(_ *http.Request, via []*http.Request) error {
|
||||
if len(via) >= 10 {
|
||||
return errors.New("stopped after 10 redirects")
|
||||
}
|
||||
if via[0].Method != http.MethodGet && via[0].Method != http.MethodHead {
|
||||
return http.ErrUseLastResponse
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func ClientFromContext(ctx context.Context) (*gitea.Client, error) {
|
||||
token, ok := ctx.Value(mcpContext.TokenContextKey).(string)
|
||||
if !ok {
|
||||
|
||||
120
pkg/gitea/redirect_test.go
Normal file
120
pkg/gitea/redirect_test.go
Normal file
@@ -0,0 +1,120 @@
|
||||
package gitea
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"gitea.com/gitea/gitea-mcp/pkg/flag"
|
||||
)
|
||||
|
||||
func TestCheckRedirect(t *testing.T) {
|
||||
for _, tc := range []struct {
|
||||
name string
|
||||
method string
|
||||
wantErr error
|
||||
}{
|
||||
{"allows GET", http.MethodGet, nil},
|
||||
{"allows HEAD", http.MethodHead, nil},
|
||||
{"blocks PATCH", http.MethodPatch, http.ErrUseLastResponse},
|
||||
{"blocks POST", http.MethodPost, http.ErrUseLastResponse},
|
||||
{"blocks PUT", http.MethodPut, http.ErrUseLastResponse},
|
||||
{"blocks DELETE", http.MethodDelete, http.ErrUseLastResponse},
|
||||
} {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
via := []*http.Request{{Method: tc.method}}
|
||||
err := checkRedirect(nil, via)
|
||||
if err != tc.wantErr {
|
||||
t.Fatalf("expected %v, got %v", tc.wantErr, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
t.Run("stops after 10 redirects", func(t *testing.T) {
|
||||
via := make([]*http.Request, 10)
|
||||
for i := range via {
|
||||
via[i] = &http.Request{Method: http.MethodGet}
|
||||
}
|
||||
err := checkRedirect(nil, via)
|
||||
if err == nil || err == http.ErrUseLastResponse {
|
||||
t.Fatalf("expected redirect limit error, got %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// TestDoJSON_RepoRenameRedirect is a regression test for the bug where a PATCH
|
||||
// request to a renamed repo got a 301 redirect, Go's http.Client silently
|
||||
// changed the method to GET, and the write appeared to succeed without error.
|
||||
func TestDoJSON_RepoRenameRedirect(t *testing.T) {
|
||||
// Simulate a Gitea API that returns 301 for the old repo name (like a renamed repo).
|
||||
mux := http.NewServeMux()
|
||||
mux.HandleFunc("PATCH /api/v1/repos/owner/old-name/pulls/1", func(w http.ResponseWriter, r *http.Request) {
|
||||
http.Redirect(w, r, "/api/v1/repos/owner/new-name/pulls/1", http.StatusMovedPermanently)
|
||||
})
|
||||
mux.HandleFunc("PATCH /api/v1/repos/owner/new-name/pulls/1", func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
fmt.Fprint(w, `{"id":1,"title":"updated"}`)
|
||||
})
|
||||
mux.HandleFunc("GET /api/v1/repos/owner/new-name/pulls/1", func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
fmt.Fprint(w, `{"id":1,"title":"not-updated"}`)
|
||||
})
|
||||
srv := httptest.NewServer(mux)
|
||||
defer srv.Close()
|
||||
|
||||
origHost := flag.Host
|
||||
defer func() { flag.Host = origHost }()
|
||||
flag.Host = srv.URL
|
||||
|
||||
var result map[string]any
|
||||
status, err := DoJSON(context.Background(), http.MethodPatch, "repos/owner/old-name/pulls/1", nil, map[string]string{"title": "updated"}, &result)
|
||||
if err != nil {
|
||||
// The redirect should be blocked, returning the 301 response directly.
|
||||
// DoJSON treats non-2xx as an error, which is the correct behavior.
|
||||
if status != http.StatusMovedPermanently {
|
||||
t.Fatalf("expected status 301, got %d (err: %v)", status, err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// If we reach here without error, the redirect was followed. Verify the
|
||||
// method was preserved (title should be "updated", not "not-updated").
|
||||
title, _ := result["title"].(string)
|
||||
if title == "not-updated" {
|
||||
t.Fatal("PATCH was silently converted to GET on 301 redirect — write was lost")
|
||||
}
|
||||
}
|
||||
|
||||
// TestDoJSON_GETRedirectFollowed verifies that GET requests still follow redirects normally.
|
||||
func TestDoJSON_GETRedirectFollowed(t *testing.T) {
|
||||
mux := http.NewServeMux()
|
||||
mux.HandleFunc("GET /api/v1/repos/owner/old-name/pulls/1", func(w http.ResponseWriter, r *http.Request) {
|
||||
http.Redirect(w, r, "/api/v1/repos/owner/new-name/pulls/1", http.StatusMovedPermanently)
|
||||
})
|
||||
mux.HandleFunc("GET /api/v1/repos/owner/new-name/pulls/1", func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
json.NewEncoder(w).Encode(map[string]any{"id": 1, "title": "found"})
|
||||
})
|
||||
srv := httptest.NewServer(mux)
|
||||
defer srv.Close()
|
||||
|
||||
origHost := flag.Host
|
||||
defer func() { flag.Host = origHost }()
|
||||
flag.Host = srv.URL
|
||||
|
||||
var result map[string]any
|
||||
status, err := DoJSON(context.Background(), http.MethodGet, "repos/owner/old-name/pulls/1", nil, nil, &result)
|
||||
if err != nil {
|
||||
t.Fatalf("GET redirect should be followed, got error: %v (status %d)", err, status)
|
||||
}
|
||||
title, _ := result["title"].(string)
|
||||
if title != "found" {
|
||||
t.Fatalf("expected title 'found', got %q", title)
|
||||
}
|
||||
}
|
||||
@@ -44,8 +44,9 @@ func newRESTHTTPClient() *http.Client {
|
||||
transport.TLSClientConfig = &tls.Config{InsecureSkipVerify: true} //nolint:gosec // user-requested insecure mode
|
||||
}
|
||||
return &http.Client{
|
||||
Transport: transport,
|
||||
Timeout: 60 * time.Second,
|
||||
Transport: transport,
|
||||
Timeout: 60 * time.Second,
|
||||
CheckRedirect: checkRedirect,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user