diff --git a/apps/workflows/api.py b/apps/workflows/api.py index 51df285f3..47122f738 100644 --- a/apps/workflows/api.py +++ b/apps/workflows/api.py @@ -13,7 +13,12 @@ from apps.taskman.service import JiraTaskmanQuerier from osidb.api_views import RudimentaryUserPathLoggingMixin, get_valid_http_methods -from osidb.helpers import get_bugzilla_api_key, get_flaw_or_404, get_jira_api_key +from osidb.helpers import ( + get_bugzilla_api_key, + get_flaw_or_404, + get_flaw_with_related_objects, + get_jira_api_key, +) from .exceptions import WorkflowsException from .helpers import str2bool @@ -110,7 +115,9 @@ def post(self, request, flaw_id): return its workflow:state classification or errors if not possible to promote """ logger.info(f"promoting flaw {flaw_id} workflow classification") - flaw = get_flaw_or_404(flaw_id) + # Use optimized queryset with prefetch_related to minimize database queries + optimized_queryset = get_flaw_with_related_objects() + flaw = get_flaw_or_404(flaw_id, queryset=optimized_queryset) try: jira_token = get_jira_api_key(request) bz_token = get_bugzilla_api_key(request) diff --git a/docs/CHANGELOG.md b/docs/CHANGELOG.md index 5859b11a8..4ebe9fef2 100644 --- a/docs/CHANGELOG.md +++ b/docs/CHANGELOG.md @@ -36,6 +36,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Uniqueness on purl or ps_component for affects (OSIDB-4644) - Make `ps_module` read-only in the affect serializer (OSIDB-4554) - Allow filtering by PURL fields (OSIDB-4728) +- Improve promote performance with prefetch (OSIDB-4678) +- Improve promote performance with ACL bulk update (OSIDB-4678) ### Added - Add `in` filter for API views (OSIDB-4588) diff --git a/osidb/helpers.py b/osidb/helpers.py index 28c14cb2e..9a61ef73f 100644 --- a/osidb/helpers.py +++ b/osidb/helpers.py @@ -85,6 +85,33 @@ def get_flaw_or_404(pk, queryset=None): raise Http404 from e +def get_flaw_with_related_objects(): + """ + Returns a Flaw queryset with prefetch_related and select_related + to minimize database queries during workflow promotion operations. + """ + from osidb.models import Flaw + + return Flaw.objects.prefetch_related( + # Flaw's direct reverse relationships + "acknowledgments", + "affects", + "comments", + "cvss_scores", + "package_versions", + "references", + "labels", + "alerts", + # Affects and their related objects + "affects__cvss_scores", + "affects__tracker", + "affects__tracker__errata", + "affects__tracker__affects", + "affects__alerts", + "affects__tracker__alerts", + ) + + # Replaces strtobool from the deprecated distutils library def strtobool(val: str): val = val.lower() diff --git a/osidb/mixins.py b/osidb/mixins.py index 3c70c4a34..91dfc2713 100644 --- a/osidb/mixins.py +++ b/osidb/mixins.py @@ -1,7 +1,8 @@ import logging import uuid +from collections import defaultdict from functools import cached_property -from itertools import chain +from itertools import batched, chain import pghistory import pgtrigger @@ -366,6 +367,7 @@ def set_public(self): >>> my_flaw.acl_read ... [UUID(...), UUID(...)] """ + self.set_acl_read(*settings.PUBLIC_READ_GROUPS) self.set_acl_write(settings.PUBLIC_WRITE_GROUP) # Update the embargoed annotation to reflect the new ACL state @@ -669,32 +671,76 @@ def unembargo(self): if related_instance: related_instance.unembargo() - def set_public_nested(self): + def set_public_nested(self, max_chunk_size=5000): """ Change internal ACLs to public ACLs for all related Flaw objects and save them. The only exception is "snippets", which should always have internal ACLs. The Flaw itself will be saved later to avoid duplicate operations. + + This method collects all objects that need to be updated and performs + chunked queryset updates to minimize database queries and avoid storing + model instances in memory. + """ + objects_to_update = defaultdict(set) # {ModelClass: {pk, ...}} + + visited = set() + self._collect_objects_for_public_update(objects_to_update, visited) + + if not any(objects_to_update.values()): + return + + # Cut off microseconds to match TrackingMixin.save() behavior + now = timezone.now().replace(microsecond=0) + + public_acl_read = [ + uuid.UUID(acl) for acl in generate_acls(settings.PUBLIC_READ_GROUPS) + ] + public_acl_write = [ + uuid.UUID(acl) for acl in generate_acls([settings.PUBLIC_WRITE_GROUP]) + ] + + for model_class, object_ids in objects_to_update.items(): + if not object_ids: + continue + + if not issubclass(model_class, ACLMixin): + continue + + update_kwargs = {"acl_read": public_acl_read, "acl_write": public_acl_write} + if issubclass(model_class, TrackingMixin): + update_kwargs["updated_dt"] = now + + for pk_chunk in batched(object_ids, max_chunk_size): + model_class.objects.filter(pk__in=pk_chunk).update(**update_kwargs) + + for pk in object_ids: + model_class(pk=pk).set_history_public() + + def _collect_objects_for_public_update(self, objects_to_update, visited): + """ + Recursively collect all related objects that need to have their ACLs + updated to public. The Flaw itself is not collected as it's saved separately. + + Uses a defaultdict(set) keyed by model class to collect pks. + Uses a set of (type, pk) to prevent infinite recursion. """ from osidb.models import Flaw + if self.pk is None: + return + + visit_key = (type(self), self.pk) + + if visit_key in visited: + return + + visited.add(visit_key) + if not isinstance(self, Flaw): if not self.is_internal: return - kwargs = {} - if issubclass(type(self), AlertMixin): - # suppress the validation errors as we expect that during - # the update the parent and child ACLs will not equal - kwargs["raise_validation_error"] = False - if issubclass(type(self), TrackingMixin): - # do not auto-update the updated_dt timestamp as the - # followup update would fail on a mid-air collision - kwargs["auto_timestamps"] = False - self.set_public() - self.set_history_public() - self.save(**kwargs) + objects_to_update[type(self)].add(self.pk) - # chain all the related instances in reverse relationships (o2m, m2m) - # as we only care for the ACLs which are unified for related_instance in chain.from_iterable( getattr(self, name).all() for name in [ @@ -707,17 +753,18 @@ def set_public_nested(self): ) ] ): - # continue deeper into the related context - related_instance.set_public_nested() - - # chain related instances in forward relationships (m2o, o2o) + related_instance._collect_objects_for_public_update( + objects_to_update, visited + ) for field in self._meta.concrete_fields: if isinstance( field, (models.ForeignKey, models.OneToOneField) ) and issubclass(field.related_model, ACLMixin): related_instance = getattr(self, field.name) if related_instance: - related_instance.set_public_nested() + related_instance._collect_objects_for_public_update( + objects_to_update, visited + ) class AlertManager(ACLMixinManager): diff --git a/osidb/tests/test_query_regresion.py b/osidb/tests/test_query_regresion.py index bc5f4d1fe..d9ae7c6fb 100644 --- a/osidb/tests/test_query_regresion.py +++ b/osidb/tests/test_query_regresion.py @@ -4,8 +4,9 @@ from django.test.utils import CaptureQueriesContext from pytest_django.asserts import assertNumQueries +from apps.workflows.workflow import WorkflowModel from osidb.api_views import FlawView -from osidb.models import Affect, Flaw, Impact, Tracker +from osidb.models import Affect, Flaw, FlawSource, Impact, Tracker from osidb.tests.factories import ( AffectFactory, FlawFactory, @@ -385,3 +386,111 @@ def _spy_prefetch_related(self, *lookups): assert not prefetch_calls, ( f"Unexpected affects prefetches during PUT: {prefetch_calls}" ) + + @pytest.mark.enable_signals + @pytest.mark.parametrize( + "embargoed,affect_quantity,expected_queries", + [ + (True, 1, 76), + (True, 10, 76), + (True, 100, 76), + (False, 1, 113), # down from 119 + (False, 10, 311), # down from 389 + (False, 100, 2291), # down from 3089 + ], + ) + def test_flaw_promote( + self, + auth_client, + enable_jira_task_async_sync, + test_api_uri, + jira_token, + bugzilla_token, + embargoed, + affect_quantity, + expected_queries, + ): + """ + Test query performance for flaws promote endpoint as number of affects increases. + """ + ps_module = PsModuleFactory() + ps_update_stream = PsUpdateStreamFactory(ps_module=ps_module) + + flaw = FlawFactory( + embargoed=embargoed, + impact=Impact.MODERATE, + major_incident_state=Flaw.FlawMajorIncident.NOVALUE, + ) + if not embargoed: + flaw.set_internal() + flaw.save() + + for _ in range(affect_quantity): + affect = AffectFactory( + flaw=flaw, + ps_update_stream=ps_update_stream.name, + affectedness=Affect.AffectAffectedness.AFFECTED, + resolution=Affect.AffectResolution.DELEGATED, + impact=Impact.MODERATE, + ) + + TrackerFactory( + affects=[affect], + embargoed=embargoed, + ps_update_stream=ps_update_stream.name, + type=Tracker.BTS2TYPE[ps_module.bts_name], + ) + # Make children internal in the non-embargoed scenario so set_public_nested has work to do + if not embargoed: + # Ensure related objects start as internal, then promotion will flip them public + affect.set_internal() + affect.save(raise_validation_error=False) + if affect.tracker: + affect.tracker.set_internal() + affect.tracker.save(raise_validation_error=False) + + # Force initial classification to start the promote chain from NEW. + # Even with signals enabled, setting task_key later can trigger workflow + # auto-adjust and skip ahead because required fields are already filled. + flaw.classification = { + "workflow": "DEFAULT", + "state": WorkflowModel.WorkflowState.NEW, + } + + # NEW -> TRIAGE requires owner + flaw.owner = "Alice" + + # TRIAGE -> PRE_SECONDARY_ASSESSMENT requires source and title + flaw.source = FlawSource.CUSTOMER + flaw.title = flaw.title or "Sample title" + flaw.save(raise_validation_error=False) + + # Ensure a Jira task exists so workflow transitions trigger adjust_acls/set_public_nested + flaw.task_key = "OSIM-1" + flaw.save(raise_validation_error=False) + + headers = { + "HTTP_JIRA_API_KEY": jira_token, + "HTTP_BUGZILLA_API_KEY": bugzilla_token, + } + + # Promote to TRIAGE + response = auth_client().post( + f"{test_api_uri}/flaws/{flaw.uuid}/promote", + data={}, + format="json", + **headers, + ) + + assert response.status_code == 200 + + # Promote to PRE_SECONDARY_ASSESSMENT using async task sync + # this one runs the nested set_public_nested call and set_history_public + + with assertNumQueriesLessThan(expected_queries): + response = auth_client().post( + f"{test_api_uri}/flaws/{flaw.uuid}/promote", + data={}, + format="json", + **headers, + )