diff --git a/operation/operation.go b/operation/operation.go index 75af00a..141cce3 100644 --- a/operation/operation.go +++ b/operation/operation.go @@ -31,23 +31,6 @@ import ( var mcpServer *server.MCPServer -type noBodyContentTypeResponseWriter struct { - http.ResponseWriter -} - -func (w *noBodyContentTypeResponseWriter) WriteHeader(statusCode int) { - if (statusCode == http.StatusAccepted || statusCode == http.StatusNoContent) && w.Header().Get("Content-Type") == "" { - w.Header().Set("Content-Type", "application/json") - } - w.ResponseWriter.WriteHeader(statusCode) -} - -func withNoBodyContentType(handler http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - handler.ServeHTTP(&noBodyContentTypeResponseWriter{ResponseWriter: w}, r) - }) -} - func RegisterTool(s *server.MCPServer) { // User Tool s.AddTools(user.Tool.Tools()...) @@ -130,16 +113,12 @@ func Run() error { return err } case "http": - mux := http.NewServeMux() - httpServer := &http.Server{Handler: mux} - streamableHTTPServer := server.NewStreamableHTTPServer( + httpServer := server.NewStreamableHTTPServer( mcpServer, server.WithLogger(log.New()), server.WithHeartbeatInterval(30*time.Second), server.WithHTTPContextFunc(getContextWithToken), - server.WithStreamableHTTPServer(httpServer), ) - mux.Handle("/mcp", withNoBodyContentType(streamableHTTPServer)) log.Infof("Gitea MCP HTTP server listening on :%d", flag.Port) // Graceful shutdown setup @@ -152,13 +131,13 @@ func Run() error { log.Infof("Shutdown signal received, gracefully stopping HTTP server...") shutdownCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() - if err := streamableHTTPServer.Shutdown(shutdownCtx); err != nil { + if err := httpServer.Shutdown(shutdownCtx); err != nil { log.Errorf("HTTP server shutdown error: %v", err) } close(shutdownDone) }() - if err := streamableHTTPServer.Start(fmt.Sprintf(":%d", flag.Port)); err != nil && !errors.Is(err, http.ErrServerClosed) { + if err := httpServer.Start(fmt.Sprintf(":%d", flag.Port)); err != nil && !errors.Is(err, http.ErrServerClosed) { return err } <-shutdownDone // Wait for shutdown to finish diff --git a/operation/operation_http_headers_test.go b/operation/operation_http_headers_test.go deleted file mode 100644 index 2e7f0b6..0000000 --- a/operation/operation_http_headers_test.go +++ /dev/null @@ -1,77 +0,0 @@ -package operation - -import ( - "net/http" - "net/http/httptest" - "testing" -) - -func TestWithNoBodyContentTypeAddsContentTypeForAcceptedAndNoContent(t *testing.T) { - tests := []struct { - name string - status int - }{ - { - name: "accepted", - status: http.StatusAccepted, - }, - { - name: "no_content", - status: http.StatusNoContent, - }, - } - - for _, tc := range tests { - t.Run(tc.name, func(t *testing.T) { - handler := withNoBodyContentType(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(tc.status) - })) - - req := httptest.NewRequest(http.MethodPost, "/mcp", nil) - rr := httptest.NewRecorder() - handler.ServeHTTP(rr, req) - - if rr.Code != tc.status { - t.Fatalf("expected status %d, got %d", tc.status, rr.Code) - } - if got := rr.Header().Get("Content-Type"); got != "application/json" { - t.Fatalf("expected Content-Type application/json, got %q", got) - } - }) - } -} - -func TestWithNoBodyContentTypePreservesExistingContentType(t *testing.T) { - handler := withNoBodyContentType(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "text/plain") - w.WriteHeader(http.StatusAccepted) - })) - - req := httptest.NewRequest(http.MethodPost, "/mcp", nil) - rr := httptest.NewRecorder() - handler.ServeHTTP(rr, req) - - if rr.Code != http.StatusAccepted { - t.Fatalf("expected status %d, got %d", http.StatusAccepted, rr.Code) - } - if got := rr.Header().Get("Content-Type"); got != "text/plain" { - t.Fatalf("expected Content-Type text/plain, got %q", got) - } -} - -func TestWithNoBodyContentTypeDoesNotModifyOtherStatusCodes(t *testing.T) { - handler := withNoBodyContentType(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusOK) - })) - - req := httptest.NewRequest(http.MethodPost, "/mcp", nil) - rr := httptest.NewRecorder() - handler.ServeHTTP(rr, req) - - if rr.Code != http.StatusOK { - t.Fatalf("expected status %d, got %d", http.StatusOK, rr.Code) - } - if got := rr.Header().Get("Content-Type"); got != "" { - t.Fatalf("expected empty Content-Type, got %q", got) - } -}