diff --git a/app/api/controller/pullreq/merge.go b/app/api/controller/pullreq/merge.go index ef22a5ee1..fc253beda 100644 --- a/app/api/controller/pullreq/merge.go +++ b/app/api/controller/pullreq/merge.go @@ -48,22 +48,17 @@ func (c *Controller) Merge( repoRef string, pullreqNum int64, in *MergeInput, -) (types.MergeResponse, error) { - var ( - sha string - pr *types.PullReq - ) - +) (*types.MergeResponse, *types.MergeViolations, error) { method, ok := in.Method.Sanitize() if !ok { - return types.MergeResponse{}, usererror.BadRequest( + return nil, nil, usererror.BadRequest( fmt.Sprintf("wrong merge method type: %s", in.Method)) } in.Method = method targetRepo, err := c.getRepoCheckAccess(ctx, session, repoRef, enum.PermissionRepoEdit) if err != nil { - return types.MergeResponse{}, fmt.Errorf("failed to acquire access to target repo: %w", err) + return nil, nil, fmt.Errorf("failed to acquire access to target repo: %w", err) } // if two requests for merging comes at the same time then mutex will lock @@ -72,50 +67,50 @@ func (c *Controller) Merge( // pr is already merged. mutex, err := c.newMutexForPR(targetRepo.GitUID, 0) // 0 means locks all PRs for this repo if err != nil { - return types.MergeResponse{}, err + return nil, nil, err } err = mutex.Lock(ctx) if err != nil { - return types.MergeResponse{}, err + return nil, nil, err } defer func() { _ = mutex.Unlock(ctx) }() - pr, err = c.pullreqStore.FindByNumber(ctx, targetRepo.ID, pullreqNum) + pr, err := c.pullreqStore.FindByNumber(ctx, targetRepo.ID, pullreqNum) if err != nil { - return types.MergeResponse{}, fmt.Errorf("failed to get pull request by number: %w", err) + return nil, nil, fmt.Errorf("failed to get pull request by number: %w", err) } if pr.Merged != nil { - return types.MergeResponse{}, usererror.BadRequest("Pull request already merged") + return nil, nil, usererror.BadRequest("Pull request already merged") } if pr.State != enum.PullReqStateOpen { - return types.MergeResponse{}, usererror.BadRequest("Pull request must be open") + return nil, nil, usererror.BadRequest("Pull request must be open") } /* if pr.SourceSHA != in.SourceSHA { - return types.MergeResponse{}, + return nil, nil, usererror.BadRequest("A newer commit is available. Only the latest commit can be merged.") } */ if pr.IsDraft { - return types.MergeResponse{}, usererror.BadRequest( + return nil, nil, usererror.BadRequest( "Draft pull requests can't be merged. Clear the draft flag first.", ) } reviewers, err := c.reviewerStore.List(ctx, pr.ID) if err != nil { - return types.MergeResponse{}, fmt.Errorf("failed to load list of reviwers: %w", err) + return nil, nil, fmt.Errorf("failed to load list of reviwers: %w", err) } targetWriteParams, err := controller.CreateRPCInternalWriteParams(ctx, c.urlProvider, session, targetRepo) if err != nil { - return types.MergeResponse{}, fmt.Errorf("failed to create RPC write params: %w", err) + return nil, nil, fmt.Errorf("failed to create RPC write params: %w", err) } sourceRepo := targetRepo @@ -123,28 +118,28 @@ func (c *Controller) Merge( if pr.SourceRepoID != pr.TargetRepoID { sourceWriteParams, err = controller.CreateRPCInternalWriteParams(ctx, c.urlProvider, session, sourceRepo) if err != nil { - return types.MergeResponse{}, fmt.Errorf("failed to create RPC write params: %w", err) + return nil, nil, fmt.Errorf("failed to create RPC write params: %w", err) } sourceRepo, err = c.repoStore.Find(ctx, pr.SourceRepoID) if err != nil { - return types.MergeResponse{}, fmt.Errorf("failed to get source repository: %w", err) + return nil, nil, fmt.Errorf("failed to get source repository: %w", err) } } isSpaceOwner, err := apiauth.IsSpaceAdmin(ctx, c.authorizer, session, targetRepo) if err != nil { - return types.MergeResponse{}, fmt.Errorf("failed to determine if the user is space admin: %w", err) + return nil, nil, fmt.Errorf("failed to determine if the user is space admin: %w", err) } checkResults, err := c.checkStore.ListResults(ctx, targetRepo.ID, pr.SourceSHA) if err != nil { - return types.MergeResponse{}, fmt.Errorf("failed to list status checks: %w", err) + return nil, nil, fmt.Errorf("failed to list status checks: %w", err) } protectionRules, err := c.protectionManager.ForRepository(ctx, targetRepo.ID) if err != nil { - return types.MergeResponse{}, fmt.Errorf("failed to fetch protection rules for the repository: %w", err) + return nil, nil, fmt.Errorf("failed to fetch protection rules for the repository: %w", err) } ruleOut, violations, err := protectionRules.CanMerge(ctx, protection.CanMergeInput{ @@ -158,10 +153,10 @@ func (c *Controller) Merge( CheckResults: checkResults, }) if err != nil { - return types.MergeResponse{}, fmt.Errorf("failed to verify protection rules: %w", err) + return nil, nil, fmt.Errorf("failed to verify protection rules: %w", err) } if protection.IsCritical(violations) { - return types.MergeResponse{RuleViolations: violations}, nil + return nil, &types.MergeViolations{RuleViolations: violations}, nil } // TODO: for forking merge title might be different? @@ -192,12 +187,22 @@ func (c *Controller) Merge( }) if err != nil { if gitrpc.ErrorStatus(err) == gitrpc.StatusNotMergeable { - return types.MergeResponse{ + return &types.MergeResponse{ + SHA: "", + BranchDeleted: false, ConflictFiles: gitrpc.AsConflictFilesError(err), RuleViolations: violations, - }, nil + }, nil, nil + // TODO: This should be the response in case of a merge conflict. + // TODO: Remove the ConflictFiles field from types.MergeResponse. + /* + return nil, &types.MergeViolations{ + ConflictFiles: gitrpc.AsConflictFilesError(err), + RuleViolations: violations, + }, nil + */ } - return types.MergeResponse{}, fmt.Errorf("merge check execution failed: %w", err) + return nil, nil, fmt.Errorf("merge check execution failed: %w", err) } pr, err = c.pullreqStore.UpdateOptLock(ctx, pr, func(pr *types.PullReq) error { @@ -219,7 +224,7 @@ func (c *Controller) Merge( return nil }) if err != nil { - return types.MergeResponse{}, fmt.Errorf("failed to update pull request: %w", err) + return nil, nil, fmt.Errorf("failed to update pull request: %w", err) } activityPayload := &types.PullRequestActivityPayloadMerge{ @@ -255,9 +260,9 @@ func (c *Controller) Merge( } } - return types.MergeResponse{ - SHA: sha, + return &types.MergeResponse{ + SHA: mergeOutput.MergeSHA, BranchDeleted: branchDeleted, RuleViolations: violations, - }, nil + }, nil, nil } diff --git a/app/api/controller/repo/commit.go b/app/api/controller/repo/commit.go index 38b3c68b4..faa7edeb1 100644 --- a/app/api/controller/repo/commit.go +++ b/app/api/controller/repo/commit.go @@ -57,20 +57,20 @@ func (c *Controller) CommitFiles(ctx context.Context, session *auth.Session, repoRef string, in *CommitFilesOptions, -) (types.CommitFilesResponse, error) { +) (types.CommitFilesResponse, []types.RuleViolations, error) { repo, err := c.getRepoCheckAccess(ctx, session, repoRef, enum.PermissionRepoPush, false) if err != nil { - return types.CommitFilesResponse{}, err + return types.CommitFilesResponse{}, nil, err } isSpaceOwner, err := apiauth.IsSpaceAdmin(ctx, c.authorizer, session, repo) if err != nil { - return types.CommitFilesResponse{}, err + return types.CommitFilesResponse{}, nil, err } protectionRules, err := c.protectionManager.ForRepository(ctx, repo.ID) if err != nil { - return types.CommitFilesResponse{}, + return types.CommitFilesResponse{}, nil, fmt.Errorf("failed to fetch protection rules for the repository: %w", err) } @@ -93,11 +93,11 @@ func (c *Controller) CommitFiles(ctx context.Context, RefNames: []string{branchName}, }) if err != nil { - return types.CommitFilesResponse{}, fmt.Errorf("failed to verify protection rules for git push: %w", err) + return types.CommitFilesResponse{}, nil, fmt.Errorf("failed to verify protection rules: %w", err) } if protection.IsCritical(violations) { - return types.CommitFilesResponse{RuleViolations: violations}, nil + return types.CommitFilesResponse{}, violations, nil } actions := make([]gitrpc.CommitFileAction, len(in.Actions)) @@ -107,7 +107,7 @@ func (c *Controller) CommitFiles(ctx context.Context, case enum.ContentEncodingTypeBase64: rawPayload, err = base64.StdEncoding.DecodeString(action.Payload) if err != nil { - return types.CommitFilesResponse{}, fmt.Errorf("failed to decode base64 payload: %w", err) + return types.CommitFilesResponse{}, nil, fmt.Errorf("failed to decode base64 payload: %w", err) } case enum.ContentEncodingTypeUTF8: fallthrough @@ -127,7 +127,7 @@ func (c *Controller) CommitFiles(ctx context.Context, // Create internal write params. Note: This will skip the pre-commit protection rules check. writeParams, err := controller.CreateRPCInternalWriteParams(ctx, c.urlProvider, session, repo) if err != nil { - return types.CommitFilesResponse{}, fmt.Errorf("failed to create RPC write params: %w", err) + return types.CommitFilesResponse{}, nil, fmt.Errorf("failed to create RPC write params: %w", err) } now := time.Now() @@ -144,9 +144,10 @@ func (c *Controller) CommitFiles(ctx context.Context, AuthorDate: &now, }) if err != nil { - return types.CommitFilesResponse{}, err + return types.CommitFilesResponse{}, nil, err } + return types.CommitFilesResponse{ CommitID: commit.CommitID, - }, nil + }, nil, nil } diff --git a/app/api/handler/pullreq/merge.go b/app/api/handler/pullreq/merge.go index 74d4314f8..f9513a6ec 100644 --- a/app/api/handler/pullreq/merge.go +++ b/app/api/handler/pullreq/merge.go @@ -50,11 +50,15 @@ func HandleMerge(pullreqCtrl *pullreq.Controller) http.HandlerFunc { return } - pr, err := pullreqCtrl.Merge(ctx, session, repoRef, pullreqNumber, in) + pr, violation, err := pullreqCtrl.Merge(ctx, session, repoRef, pullreqNumber, in) if err != nil { render.TranslatedUserError(w, err) return } + if violation != nil { + render.Unprocessable(w, violation) + return + } render.JSON(w, http.StatusOK, pr) } diff --git a/app/api/handler/repo/commit.go b/app/api/handler/repo/commit.go index d430fbbdd..f83c18916 100644 --- a/app/api/handler/repo/commit.go +++ b/app/api/handler/repo/commit.go @@ -40,11 +40,15 @@ func HandleCommitFiles(repoCtrl *repo.Controller) http.HandlerFunc { render.BadRequestf(w, "Invalid request body: %s.", err) return } - response, err := repoCtrl.CommitFiles(ctx, session, repoRef, in) + response, violations, err := repoCtrl.CommitFiles(ctx, session, repoRef, in) if err != nil { render.TranslatedUserError(w, err) return } + if violations != nil { + render.Violations(w, violations) + return + } render.JSON(w, http.StatusOK, response) } diff --git a/app/api/openapi/pullreq.go b/app/api/openapi/pullreq.go index ff2fa2a3a..dd62043ba 100644 --- a/app/api/openapi/pullreq.go +++ b/app/api/openapi/pullreq.go @@ -456,6 +456,7 @@ func pullReqOperations(reflector *openapi3.Reflector) { _ = reflector.SetJSONResponse(&reviewSubmit, new(usererror.Error), http.StatusForbidden) _ = reflector.Spec.AddOperation(http.MethodPost, "/repos/{repo_ref}/pullreq/{pullreq_number}/reviews", reviewSubmit) + mergePullReqOp := openapi3.Operation{} mergePullReqOp.WithTags("pullreq") mergePullReqOp.WithMapOfAnything(map[string]interface{}{"operationId": "mergePullReqOp"}) @@ -467,7 +468,7 @@ func pullReqOperations(reflector *openapi3.Reflector) { _ = reflector.SetJSONResponse(&mergePullReqOp, new(usererror.Error), http.StatusNotFound) _ = reflector.SetJSONResponse(&mergePullReqOp, new(usererror.Error), http.StatusMethodNotAllowed) _ = reflector.SetJSONResponse(&mergePullReqOp, new(usererror.Error), http.StatusConflict) - _ = reflector.SetJSONResponse(&mergePullReqOp, new(usererror.Error), http.StatusUnprocessableEntity) + _ = reflector.SetJSONResponse(&mergePullReqOp, new(types.MergeViolations), http.StatusUnprocessableEntity) _ = reflector.Spec.AddOperation(http.MethodPost, "/repos/{repo_ref}/pullreq/{pullreq_number}/merge", mergePullReqOp) diff --git a/app/api/openapi/repo.go b/app/api/openapi/repo.go index 7ec0ea70e..fb3046cfc 100644 --- a/app/api/openapi/repo.go +++ b/app/api/openapi/repo.go @@ -689,6 +689,7 @@ func repoOperations(reflector *openapi3.Reflector) { _ = reflector.SetJSONResponse(&opCommitFiles, new(usererror.Error), http.StatusUnauthorized) _ = reflector.SetJSONResponse(&opCommitFiles, new(usererror.Error), http.StatusForbidden) _ = reflector.SetJSONResponse(&opCommitFiles, new(usererror.Error), http.StatusNotFound) + _ = reflector.SetJSONResponse(&opCommitFiles, new(types.RulesViolations), http.StatusUnprocessableEntity) _ = reflector.Spec.AddOperation(http.MethodPost, "/repos/{repo_ref}/commits", opCommitFiles) opDiff := openapi3.Operation{} diff --git a/app/api/render/render.go b/app/api/render/render.go index 10893e43d..e1fdbc61b 100644 --- a/app/api/render/render.go +++ b/app/api/render/render.go @@ -65,12 +65,7 @@ func BadRequest(w http.ResponseWriter) { UserError(w, usererror.ErrBadRequest) } -// BadRequestError writes the json-encoded error with a bad request status code. -func BadRequestError(w http.ResponseWriter, err *usererror.Error) { - UserError(w, err) -} - -// BadRequest writes the json-encoded message with a bad request status code. +// BadRequestf writes the json-encoded message with a bad request status code. func BadRequestf(w http.ResponseWriter, format string, args ...interface{}) { ErrorMessagef(w, http.StatusBadRequest, format, args...) } @@ -98,21 +93,9 @@ func DeleteSuccessful(w http.ResponseWriter) { // JSON writes the json-encoded value to the response // with the provides status. func JSON(w http.ResponseWriter, code int, v interface{}) { - // set common headers - w.Header().Set("Content-Type", "application/json; charset=utf-8") - w.Header().Set("X-Content-Type-Options", "nosniff") - - // flush the headers - before body or status will be 200 OK + setCommonHeaders(w) w.WriteHeader(code) - - // write body - enc := json.NewEncoder(w) - if indent { // is this necessary? it will affect performance - enc.SetIndent("", " ") - } - if err := enc.Encode(v); err != nil { - log.Err(err).Msgf("Failed to write json encoding to response body.") - } + writeJSON(w, v) } // Reader reads the content from the provided reader and writes it as is to the response body. @@ -157,8 +140,7 @@ func JSONArrayDynamic[T comparable](ctx context.Context, w http.ResponseWriter, } if count == 0 { - w.Header().Set("Content-Type", "application/json; charset=utf-8") - w.Header().Set("X-Content-Type-Options", "nosniff") + setCommonHeaders(w) w.WriteHeader(http.StatusOK) _, _ = w.Write([]byte{'['}) } else { @@ -176,3 +158,28 @@ func JSONArrayDynamic[T comparable](ctx context.Context, w http.ResponseWriter, _, _ = w.Write([]byte{']'}) } + +func Unprocessable(w http.ResponseWriter, v any) { + JSON(w, http.StatusUnprocessableEntity, v) +} + +func Violations(w http.ResponseWriter, violations []types.RuleViolations) { + Unprocessable(w, types.RulesViolations{ + Violations: violations, + }) +} + +func setCommonHeaders(w http.ResponseWriter) { + w.Header().Set("Content-Type", "application/json; charset=utf-8") + w.Header().Set("X-Content-Type-Options", "nosniff") +} + +func writeJSON(w http.ResponseWriter, v any) { + enc := json.NewEncoder(w) + if indent { + enc.SetIndent("", " ") + } + if err := enc.Encode(v); err != nil { + log.Err(err).Msgf("Failed to write json encoding to response body.") + } +} diff --git a/types/commit.go b/types/commit.go index 0156fb0fb..519907bc0 100644 --- a/types/commit.go +++ b/types/commit.go @@ -16,6 +16,5 @@ package types // CommitFilesResponse holds commit id. type CommitFilesResponse struct { - CommitID string `json:"commit_id"` - RuleViolations []RuleViolations `json:"rule_violations,omitempty"` + CommitID string `json:"commit_id"` } diff --git a/types/pullreq.go b/types/pullreq.go index d93a79730..4d2c1d9f5 100644 --- a/types/pullreq.go +++ b/types/pullreq.go @@ -144,3 +144,8 @@ type MergeResponse struct { ConflictFiles []string `json:"conflict_files,omitempty"` RuleViolations []RuleViolations `json:"rule_violations,omitempty"` } + +type MergeViolations struct { + ConflictFiles []string `json:"conflict_files,omitempty"` + RuleViolations []RuleViolations `json:"rule_violations,omitempty"` +} diff --git a/types/rule.go b/types/rule.go index 45630edc1..5fa02bfbd 100644 --- a/types/rule.go +++ b/types/rule.go @@ -102,3 +102,7 @@ type RuleInfoInternal struct { Pattern json.RawMessage Definition json.RawMessage } + +type RulesViolations struct { + Violations []RuleViolations `json:"violations"` +}