diff --git a/.gitignore b/.gitignore index f0fd22193..b159d8c7c 100644 --- a/.gitignore +++ b/.gitignore @@ -37,3 +37,6 @@ osidb_data_backup_dump* .DS_Store # pyright local configuration file pyrightconfig.json + +# performance tests output +performance_report* diff --git a/.secrets.baseline b/.secrets.baseline index 259116c70..f46f12b67 100644 --- a/.secrets.baseline +++ b/.secrets.baseline @@ -470,10 +470,10 @@ "filename": "tox.ini", "hashed_secret": "5a4fe08359c7f97380e408c717ef42c86939cd86", "is_verified": false, - "line_number": 51, + "line_number": 54, "is_secret": false } ] }, - "generated_at": "2025-11-25T20:59:22Z" + "generated_at": "2025-11-27T15:08:04Z" } diff --git a/osidb/tests/performance/__init__.py b/osidb/tests/performance/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/osidb/tests/performance/auditor.py b/osidb/tests/performance/auditor.py new file mode 100644 index 000000000..73c06762d --- /dev/null +++ b/osidb/tests/performance/auditor.py @@ -0,0 +1,165 @@ +import cProfile +import time +from collections import Counter + +from django.conf import settings +from django.db import connection, reset_queries +from django.test.utils import CaptureQueriesContext + +from .utils import clean_sql, extract_tables, fingerprint_sql + +SLOW_QUERY_THRESHOLD = 0.04 +N_PLUS_ONE_THRESHOLD = 3 + + +class PerformanceAuditor: + """ + Context manager for comprehensive performance auditing. 
class PerformanceAuditor:
    """
    Context manager for comprehensive performance auditing.

    Tracks:
    - Database queries (count, duplicates, N+1 patterns, slow queries, mutations)
    - Execution time (total, Python, database)
    - CPU profiling
    """

    def __init__(self, enable_profiling=True):
        self.total_time = 0
        self.db_time = 0
        # Initialized here so get_summary() cannot hit AttributeError when the
        # auditor was created but its context was never entered.
        self.cpu_time = 0
        self.queries = []
        self.exact_duplicates = []
        self.query_map = {}  # fingerprint -> {"durations": [...], "sql": first raw SQL}
        self.n_plus_one_suspects = []
        self.writes_detected = []
        self.table_counts = Counter()
        self.slow_query_plans = []

        self.profiler = cProfile.Profile() if enable_profiling else None

    def __enter__(self):
        # Ensure DEBUG is disabled to mimic a production environment
        self.old_debug = settings.DEBUG
        settings.DEBUG = False

        # Reset query count and enable capture
        reset_queries()
        self.queries_ctx = CaptureQueriesContext(connection).__enter__()

        self.start_time = time.perf_counter()
        self.start_cpu = time.process_time()

        if self.profiler:
            self.profiler.enable()

        return self

    def __exit__(self, exc_type, exc_value, traceback):
        if self.profiler:
            self.profiler.disable()

        # Calculate times
        self.end_time = time.perf_counter()
        self.end_cpu = time.process_time()
        self.total_time = self.end_time - self.start_time
        self.cpu_time = self.end_cpu - self.start_cpu

        # Restore DEBUG
        settings.DEBUG = self.old_debug

        # Stop query capture
        self.queries_ctx.__exit__(exc_type, exc_value, traceback)
        self.queries = self.queries_ctx.captured_queries

        self._analyze_results()

    def _analyze_results(self):
        """
        Process the raw captured queries into more meaningful stats.
        """
        cleaned_queries = []

        for query in self.queries:
            sql = clean_sql(query["sql"])
            cleaned_queries.append(sql)

            # Calculate DB time (Django stores it as string seconds)
            duration = float(query["time"])
            self.db_time += duration

            # Detect mutations
            if any(
                action in sql for action in ["INSERT INTO", "UPDATE ", "DELETE FROM"]
            ):
                # Ignore framework bookkeeping writes (auth/session tables)
                if "auth_" not in sql:
                    self.writes_detected.append(sql)

            # Extract tables accessed
            for table in extract_tables(sql):
                self.table_counts[table] += 1

            # Fingerprint query shape
            fp = fingerprint_sql(sql)
            if fp not in self.query_map:
                self.query_map[fp] = {
                    "durations": [],
                    "sql": sql,  # Capture the first raw SQL with this shape
                }
            self.query_map[fp]["durations"].append(duration)

            # Analyze slow queries
            if duration > SLOW_QUERY_THRESHOLD:
                self._run_explain(sql, duration)

        # Detect duplicated queries; reuse the already-cleaned SQL instead of
        # re-cleaning every captured statement a second time.
        for sql, count in Counter(cleaned_queries).items():
            if count > 1 and "SAVEPOINT" not in sql:
                self.exact_duplicates.append({"sql": sql, "count": count})

        # Detect N+1:
        # if the same query shape runs > N_PLUS_ONE_THRESHOLD times, flag it
        for fp, data in self.query_map.items():
            durations = data["durations"]
            count = len(durations)
            if count > N_PLUS_ONE_THRESHOLD:
                self.n_plus_one_suspects.append(
                    {
                        "fingerprint": fp,
                        "sql": data["sql"],
                        "count": count,
                        "avg_time": sum(durations) / count,
                        "total_time": sum(durations),
                    }
                )

    def _run_explain(self, sql, duration):
        """
        Capture the PostgreSQL execution plan for a slow query.

        EXPLAIN ANALYZE *executes* the statement, so running it on a slow
        INSERT/UPDATE/DELETE would apply the mutation a second time.
        Only read-only statements are analyzed.
        """
        if not sql.lstrip().upper().startswith("SELECT"):
            return

        # We need to be careful not to break the transaction state;
        # this is usually safe for SELECTs in tests.
        try:
            with connection.cursor() as cursor:
                cursor.execute(f"EXPLAIN ANALYZE {sql}")
                # Flatten the plan rows into a single string
                plan_text = "\n".join(row[0] for row in cursor.fetchall())

            self.slow_query_plans.append(
                {"sql": sql, "duration": duration, "plan": plan_text}
            )
        except Exception as e:
            # Don't crash the test if EXPLAIN fails (syntax errors, etc)
            print(f"Could not explain query: {e}")

    def get_summary(self):
        """Returns a dictionary summary for the report generator."""
        return {
            "total_duration": self.total_time,
            "cpu_duration": self.cpu_time,
            "db_duration": self.db_time,
            "query_count": len(self.queries),
            "exact_duplicates": self.exact_duplicates,
            "n_plus_one_suspects": self.n_plus_one_suspects,
            "writes_detected": self.writes_detected,
            # most_common() sorts by access count, descending
            "table_breakdown": dict(self.table_counts.most_common()),
            "slow_query_plans": self.slow_query_plans,
        }


# --- osidb/tests/performance/conftest.py ---
import os
from datetime import datetime

import pytest

from .auditor import PerformanceAuditor
from .utils import get_profile_stats, get_safe_filename

# Global store for results
# (NOTE: incompatible with pytest-xdist parallelization without extra work)
PERFORMANCE_RESULTS = []


@pytest.fixture(scope="function")
def performance_audit(request):
    """Yield a fresh auditor; collect its summary after the test finishes."""
    auditor = PerformanceAuditor()

    # Run the test
    yield auditor

    # Check if the test actually used the auditor and skip reporting otherwise
    if auditor.total_time == 0:
        return

    summary = auditor.get_summary()
    summary["test_name"] = request.node.name
    summary["profile"] = get_profile_stats(auditor.profiler)
    # We could export the profile with auditor.profiler.dump_stats()
    # for manual review if needed
    PERFORMANCE_RESULTS.append(summary)
+ """ + if not PERFORMANCE_RESULTS: + return + + terminalreporter.section("Performance Regression Report") + + # Table Header + header = f"{'Test Name':<40} | {'Time':<8} | {'CPU Time':<8} | {'DB Time':<8} | {'Queries':<7} | {'Tables':<6} | {'N+1':<4} | {'Writes':<6} | {'Dup Queries':<11} | {'Slow queries':<12}" + terminalreporter.write_line("-" * len(header)) + terminalreporter.write_line(header) + terminalreporter.write_line("-" * len(header)) + + # Table Rows + for res in PERFORMANCE_RESULTS: + name = ( + res["test_name"] + if len(res["test_name"]) <= 40 + else f"{res['test_name'][:37]}..." + ) + row = ( + f"{name:<40} | " + f"{res['total_duration'] * 1000:>6.0f}ms | " + f"{res['cpu_duration'] * 1000:>6.0f}ms | " + f"{res['db_duration'] * 1000:>6.0f}ms | " + f"{res['query_count']:>7} | " + f"{len(res['table_breakdown'].keys()):>6} | " + f"{len(res['n_plus_one_suspects']):>4} | " + f"{len(res['writes_detected']):>6} | " + f"{len(res['exact_duplicates']):>11} | " + f"{len(res['slow_query_plans']):>12}" + ) + terminalreporter.write_line(row) + generate_markdown_report() + + +def generate_markdown_report(): + """Generates a detailed GITHUB_STEP_SUMMARY compliant markdown file.""" + + def get_color(time): + if time < 100: + return "green" + elif time < 200: + return "blue" + elif time < 500: + return "orange" + else: + return "red" + + run_date = datetime.now() + + if ( + "CI" not in os.environ + or not os.environ["CI"] + or "GITHUB_RUN_ID" not in os.environ + ): + file = f"performance_report_{run_date.strftime('%Y-%m-%d_%H-%M')}.md" + mode = "w" + else: + file = os.environ["GITHUB_STEP_SUMMARY"] + mode = "a" + + with open(file, mode) as report: + report.write("# 🚀 Detailed Performance Analysis\n\n") + + # Summary table + report.write("## 📊 Executive Summary\n") + report.write( + "| Test | Time | CPU Time | DB Time | Queries | Tables | N+1 | Writes | Dup Queries | Slow queries |\n" + ) + report.write("|" + ("---|" * 10) + "\n") + for res in PERFORMANCE_RESULTS: + 
report.write( + f"| [{res['test_name']}](#{get_safe_filename(res['test_name'])})" + f"| {res['total_duration'] * 1000:.0f}ms" + f"| {res['cpu_duration'] * 1000:.0f}ms" + f"| {res['db_duration'] * 1000:.0f}ms" + f"| {res['query_count']:>7}" + f"| {len(res['table_breakdown'].keys())}" + f"| {len(res['n_plus_one_suspects'])}" + f"| {len(res['writes_detected'])}" + f"| {len(res['exact_duplicates'])}" + f"| {len(res['slow_query_plans'])}|\n" + ) + + # Detailed breakdown per test + report.write("\n---\n") + report.write("## 🔬 Deep Dive per Test\n") + for res in PERFORMANCE_RESULTS: + total_duration = res["total_duration"] * 1000 + cpu_duration = res["cpu_duration"] * 1000 + db_duration = res["db_duration"] * 1000 + report.write( + f'\n

{res["test_name"]}

\n\n' + ) + report.write( + f"![Total Time](https://img.shields.io/badge/total-{total_duration:.0f}ms-{get_color(total_duration)}) " + ) + report.write( + f"![CPU Time](https://img.shields.io/badge/cpu-{cpu_duration:.0f}ms-{get_color(cpu_duration)}) " + ) + report.write( + f"![DB Time](https://img.shields.io/badge/db-{db_duration:.0f}ms-{get_color(db_duration)})\n\n" + ) + + if res["table_breakdown"]: + report.write("\n#### 📦 Database Model Access\n") + report.write("
\n") + report.write( + f"x{len(res['table_breakdown'].keys())} Tables accessed\n\n" + ) + report.write("| Table Name | Access Count |\n") + report.write("|---|---|\n") + for table, count in res["table_breakdown"].items(): + report.write(f"| `{table}` | {count} |\n") + report.write("
\n\n") + + writes = res.get("writes_detected", []) + if writes: + report.write("\n#### ⚠️ Mutation Warning\n") + report.write( + f"This endpoint triggered **{len(writes)}** write operations.\n\n" + ) + + report.write("
\nView Write Operations\n\n") + for sql in writes: + # Truncate very long SQL for readability + display_sql = sql[:200] + "..." if len(sql) > 200 else sql + report.write(f"- `{display_sql}`\n") + report.write("
\n") + + dupes = res.get("exact_duplicates", []) + if dupes: + report.write("\n#### ♻️ Redundant Exact Queries\n") + report.write( + "The exact same SQL (same parameters) was executed multiple times.\n\n" + ) + report.write("| Count | Query Sample |\n") + report.write("|---|---|\n") + for d in dupes: + # Clean up newlines for table formatting + clean_sql = d["sql"].replace("\n", " ").strip() + # Truncate center if too long + if len(clean_sql) > 80: + short_sql = clean_sql[:40] + " ... " + clean_sql[-35:] + else: + short_sql = clean_sql + report.write(f"| **x{d['count']}** | `{short_sql}` |\n") + + if len(res["n_plus_one_suspects"]) > 0: + report.write("#### ⚠️ N+1 Detected\n") + report.write( + "The same SQL was executed multiple times with different parameters.\n\n" + ) + for suspect in res["n_plus_one_suspects"]: + report.write("
\n") + report.write( + f"x{suspect['count']} queries\n\n" + ) + report.write("**Fingerprint:**\n") + report.write(f"```sql\n{suspect['fingerprint']}\n```\n") + report.write("**Example Raw SQL:**\n") + report.write(f"```sql\n{suspect['sql']}\n```\n") + report.write("
\n\n") + + if res["slow_query_plans"]: + report.write("\n#### 🐢 Slow Query Analysis\n") + report.write( + f"Found **{len(res['slow_query_plans'])}** queries exceeding the threshold.\n\n" + ) + + for i, item in enumerate(res["slow_query_plans"], 1): + duration_ms = item["duration"] * 1000 + report.write("
\n") + report.write( + f"{i}. Slow Query ({duration_ms:.1f}ms) - Click to view Plan\n\n" + ) + report.write("**SQL:**\n") + report.write(f"```sql\n{item['sql']}\n```\n") + report.write("**PostgreSQL Query Plan:**\n") + report.write(f"```yaml\n{item['plan']}\n```\n") + report.write("
\n\n") + + # CPU Profile Analysis + profile_data = res.get("profile") + if profile_data and profile_data.get("top_functions"): + report.write("\n#### 🔥 CPU Profile Analysis\n") + report.write( + f"**Total profiled time**: {profile_data['total_time'] * 1000:.1f}ms " + f"({profile_data['total_calls']:,} function calls)\n\n" + ) + + # Category breakdown (collapsible) + report.write("
\n") + report.write( + "Category Breakdown\n\n" + ) + report.write("| Category | Time (ms) | % of Total | Calls |\n") + report.write("|---|---:|---:|---:|\n") + for cat, stats in sorted( + profile_data["category_breakdown"].items(), + key=lambda x: x[1]["time"], + reverse=True, + ): + if stats["time"] > 0: + report.write( + f"| {cat.replace('_', ' ').title()} | " + f"{stats['time'] * 1000:.2f} | " + f"{stats['percent']:.1f}% | " + f"{stats['calls']:,} |\n" + ) + report.write("
\n\n") + + # Top 10 functions (collapsible) + report.write("
\n") + report.write( + "Top 10 Time-Consuming Functions\n\n" + ) + report.write("| Function | Category | Calls | Time | % | Per Call |\n") + report.write("|---|---|---:|---:|---:|---:|\n") + for func in profile_data["top_functions"]: + func_display = f"`{func['filename']}:{func['line_number']}` {func['function_name']}" + if len(func_display) > 60: + func_display = func_display[:57] + "..." + + report.write( + f"| {func_display} | " + f"{func['category']} | " + f"{func['call_count']:,} | " + f"{func['total_time'] * 1000:.2f}ms | " + f"{func['time_percent']:.1f}% | " + f"{func['per_call_time'] * 1000:.3f}ms |\n" + ) + report.write("
\n\n") + + report.write("\n---\n") diff --git a/osidb/tests/performance/test_basic_usage.py b/osidb/tests/performance/test_basic_usage.py new file mode 100644 index 000000000..a29189315 --- /dev/null +++ b/osidb/tests/performance/test_basic_usage.py @@ -0,0 +1,145 @@ +import pytest +from django.test import RequestFactory +from django.urls import resolve + +from osidb.api_views import FlawView +from osidb.models import Affect, Flaw, Impact, Tracker +from osidb.serializer import FlawSerializer +from osidb.tests.factories import ( + AffectFactory, + FlawFactory, + PsModuleFactory, + PsUpdateStreamFactory, + TrackerFactory, +) + +pytestmark = pytest.mark.perf + + +def test_flaw_details_with_client(auth_client, test_api_v2_uri, performance_audit): + """ + Sample test to demostrate the overhead introduced by the django client. + """ + flaw = FlawFactory( + embargoed=False, + impact=Impact.LOW, + major_incident_state=Flaw.FlawMajorIncident.NOVALUE, + ) + with performance_audit: + # Django client simulates a browser request, creating a user and setting + # all the permissions, doing a POST to /login... + # This is perfect for normal testing scenarios but it adds extra roundtrips + # when testing performance (time, queries...) + response = auth_client().get(f"{test_api_v2_uri}/flaws/{flaw.uuid}") + + # We should keep the asserts outside of the performance_audit context to + # not pollute the analysis with pytest code + assert response.status_code == 200 + + +def test_flaw_details_with_factory(test_api_v2_uri, performance_audit): + """ + Sample test to demostrate the overhead introduced by the django client. 
+ + Compared to the test above using the client, it should be almost 200ms faster with 3 or for times + less database queries + """ + flaw = FlawFactory( + embargoed=False, + impact=Impact.LOW, + major_incident_state=Flaw.FlawMajorIncident.NOVALUE, + ) + + # RequestFactory is a Django helper for creating mock requests + # bypassing all the user creation/login + request = RequestFactory().get(f"{test_api_v2_uri}/flaws/{flaw.uuid}") + with performance_audit: + # Using the mock request we can call the view directly, passing the parameters that + # normally are extracted from the path. + response = FlawView.as_view({"get": "retrieve"})(request, id=str(flaw.uuid)) + assert response.status_code == 200 + + +@pytest.mark.parametrize( + "url", + [ + ("/flaws"), + ("/flaws?include_history=true"), + ("/flaws?exclude_fields=affects"), + ( + "/flaws?include_fields=cve_id,uuid,impact,source,created_dt,updated_dt,classification,title,unembargo_dt,embargoed,owner,labels" + ), + ("/affects"), + ("/affects?include_history=true"), + ], +) +def test_list_endpoints(url, auth_client, test_api_v2_uri, performance_audit): + # Setup code that will not be audited for performance + for _ in range(10): + flaw = FlawFactory( + embargoed=False, + impact=Impact.LOW, + major_incident_state=Flaw.FlawMajorIncident.NOVALUE, + ) + AffectFactory.create_batch( + 3, + flaw=flaw, + affectedness=Affect.AffectAffectedness.AFFECTED, + resolution=Affect.AffectResolution.DELEGATED, + impact=Impact.MODERATE, + ) + + # We can use the `resolve` function to get the View that would render that path + view = resolve(f"/osidb/api/v2{url.split('?')[0]}").func + request = RequestFactory().get(f"{test_api_v2_uri}{url}") + + with performance_audit: + response = view(request) + assert response.status_code == 200 + + +def test_fn_call(performance_audit): + """ + Test to show that we can not only test views, but any function + """ + flaw = FlawFactory( + embargoed=False, + impact=Impact.LOW, + 
major_incident_state=Flaw.FlawMajorIncident.NOVALUE, + ) + for _ in range(3): + ps_module = PsModuleFactory() + ps_update_stream = PsUpdateStreamFactory(ps_module=ps_module) + affect = AffectFactory( + flaw=flaw, + affectedness=Affect.AffectAffectedness.AFFECTED, + resolution=Affect.AffectResolution.DELEGATED, + impact=Impact.MODERATE, + ps_update_stream=ps_update_stream.name, + ) + for _ in range(3): + TrackerFactory( + affects=[affect], + embargoed=False, + ps_update_stream=ps_update_stream.name, + type=Tracker.BTS2TYPE[ps_module.bts_name], + ) + + flaw = Flaw.objects.get(pk=flaw.pk) + flaw_serializer = FlawSerializer(flaw) + + with performance_audit: + affects = flaw_serializer.get_affects(flaw) + assert len(affects) == 3 + + +def test_create_flaw(performance_audit): + flaw = FlawFactory.build( + embargoed=False, + impact=Impact.LOW, + major_incident_state=Flaw.FlawMajorIncident.NOVALUE, + ) + + with performance_audit: + flaw.save() + assert Flaw.objects.filter(cve_id=flaw.cve_id).count() == 1 diff --git a/osidb/tests/performance/utils.py b/osidb/tests/performance/utils.py new file mode 100644 index 000000000..e53330013 --- /dev/null +++ b/osidb/tests/performance/utils.py @@ -0,0 +1,303 @@ +import cProfile +import hashlib +import pstats +import re +from collections import defaultdict + + +def get_safe_filename(test_name): + """ + Converts a messy Pytest node name into a safe, valid filename. + Input: "test_endpoint[/api/v1/flaws?id=1]" + Output: "test_endpoint_api_v1_flaws_id_1_[hash]" + """ + # 1. Replace URL-like slashes and special chars with underscores + # Keep only alphanumerics, underscores, and hyphens + safe_name = re.sub(r"[^a-zA-Z0-9_\-]", "_", test_name) + + # 2. Collapse multiple underscores + safe_name = re.sub(r"_{2,}", "_", safe_name) + + # 3. 
Truncate if too long (OS limit is usually 255) + # If we cut it off, we add a hash of the original name to ensure uniqueness + if len(safe_name) > 100: + name_hash = hashlib.md5(test_name.encode("utf-8")).hexdigest()[:8] + safe_name = f"{safe_name[:100]}_{name_hash}" + + return safe_name + + +def clean_sql(sql: str) -> str: + """ + Removes pghistory context injection or other middleware noise + to reveal the actual application query. + + Input: "SELECT set_config('pghistory...', ...); SELECT ..." + Output: "SELECT ..." + """ + # Regex explanation: + # ^SELECT set_config : Must start with the config call + # \('pghistory\. : specific to pghistory to avoid false positives + # .*? : match arguments non-greedily + # \); : match the closing of the config and the semicolon + # \s* : remove trailing whitespace before next query + pghistory_pattern = r"^SELECT set_config\('pghistory\..*?\);\s*" + + # Remove the prefix + cleaned = re.sub(pghistory_pattern, "", sql, flags=re.DOTALL | re.IGNORECASE) + + return cleaned.strip() + + +def fingerprint_sql(sql: str) -> str: + """ + Normalizes a SQL query to a generic fingerprint to detect duplicates. + + Transformations: + 1. Replaces specific numeric values with '%d' + 2. Replaces quoted string LITERALS (single quotes) with '%s' + 3. Collapses IN clauses + 4. PRESERVES double-quoted identifiers (tables/columns) + """ + # 1. Replace hex/binary blobs (e.g. x'05A...') + sql = re.sub(r"x'[0-9a-f]+'", "'%b'", sql) + + # 2. Collapse IN clauses: IN (1, 2, 3) -> IN (...) + sql = re.sub(r"\bIN\s*\([^\)]+\)", "IN (...)", sql, flags=re.IGNORECASE) + + # 3. Replace String Literals: 'hello' -> '%s' + # We match anything inside single quotes. + # Note: This handles escaped single quotes inside strings if they follow SQL standard ('') + sql = re.sub(r"'(?:''|[^'])*'", "'%s'", sql) + + # 4. 
Replace Numbers: 123 -> %d + # We use \b to ensure we don't break table names like "table_2" + # But we must be careful not to break UUID casts like '...':uuid + # The previous single-quote replacement handles the UUID value '...', + # so the ::uuid part remains as literal text, which is fine. + sql = re.sub(r"\b\d+\b", "%d", sql) + + # 5. Whitespace cleanup + return " ".join(sql.split()) + + +def extract_tables(sql: str) -> list[str]: + """ + Extracts table names from a SQL query. + Matches: FROM "table_name", JOIN "table_name", UPDATE "table_name", etc. + """ + # This regex looks for keywords (FROM, JOIN, INTO, UPDATE) + # followed optionally by whitespace and then the table name (possibly quoted) + # It handles standard Django SQL generation. + pattern = r'(?:FROM|JOIN|UPDATE|INTO)\s+(?:"|`?)([a-zA-Z0-9_]+)(?:"|`?)' + + matches = re.findall(pattern, sql, flags=re.IGNORECASE) + + # Deduplicate tables within a single query (e.g. self-joins) + # or return all to weight complexity? Let's return unique per query. + return list(set(matches)) + + +def _empty_result(): + """Return empty result structure for edge cases.""" + return { + "total_time": 0, + "total_calls": 0, + "top_functions": [], + "category_breakdown": {}, + } + + +def _strip_dirs(filename): + """Strip common prefixes to shorten paths for readability.""" + # Remove absolute path prefix for the project + filename = filename.replace("/home/atinocom/Documents/work/osidb/", "") + + # Remove site-packages prefix + filename = re.sub(r".*/site-packages/", "", filename) + + # Remove usr/lib prefix + filename = re.sub(r"/usr/lib.*?/python\d+\.\d+/", "", filename) + + return filename + + +def _categorize_function(filename, func_name): + """ + Categorizes a function based on Django-specific patterns. 
+ + Categories (in priority order): + - serializer: DRF serializers + - view: Django/DRF views and viewsets + - orm: Database operations (Django ORM, psycopg2) + - framework: Django/DRF framework code + - business_logic: Project-specific code + - stdlib: Python standard library + - other: Third-party libraries + """ + # 1. Serializers (check first, specific pattern) + if any( + pattern in filename + for pattern in ["/serializer.py", "/serializers.py", "/serializers/"] + ): + return "serializer" + if "rest_framework.serializers" in filename or "drf_spectacular" in filename: + return "serializer" + + # 2. Views + if any( + pattern in filename + for pattern in ["/api_views.py", "/views.py", "/viewsets.py"] + ): + return "view" + if any( + mod in filename + for mod in ["rest_framework.views", "rest_framework.viewsets", "django.views"] + ): + return "view" + + # 3. ORM (database operations) + if any( + pattern in filename for pattern in ["/django/db/", "/psycopg2", "/postgresql"] + ): + return "orm" + if ( + "django.db.models" in filename + or "django.db.backends" in filename + or "psycopg2" in filename + ): + return "orm" + if any(fn in func_name.lower() for fn in ["execute", "fetch", "cursor"]): + return "orm" + + # 4. Framework (Django/DRF internals) + if any( + pattern in filename + for pattern in ["/django/", "/rest_framework/", "/celery/", "/redis/"] + ): + return "framework" + + # 5. Business Logic (project-specific code) + if any(pattern in filename for pattern in ["/osidb/", "/apps/", "/collectors/"]): + if not any( + exclude in filename for exclude in ["/tests/", "/migrations/", "/test_"] + ): + return "business_logic" + + # 6. Stdlib (Python standard library) + if any( + pattern in filename + for pattern in ["/usr/lib/python", "/lib/python", ""] + ): + return "stdlib" + + # 7. Default to other + return "other" + + +def _should_include(func_data, total_time, category): + """ + Determines if a function should be included in the report. 
+ + Rules: + - Always include: serializer, orm, view, business_logic, framework + - Filter unless >5% total time: stdlib, other + """ + # Always include Django-specific categories + if category in ["serializer", "orm", "view", "business_logic", "framework"]: + return True + + # Filter stdlib/other unless >5% threshold + if category in ["stdlib", "other"]: + time_percent = ( + (func_data["total_time"] / total_time * 100) if total_time > 0 else 0 + ) + return time_percent >= 5.0 + + return True + + +def _aggregate_by_category(functions, total_time): + """Aggregate stats by category.""" + breakdown = defaultdict(lambda: {"time": 0, "calls": 0}) + + for func in functions: + cat = func["category"] + breakdown[cat]["time"] += func["total_time"] + breakdown[cat]["calls"] += func["call_count"] + + # Calculate percentages + for cat in breakdown: + breakdown[cat]["percent"] = ( + (breakdown[cat]["time"] / total_time * 100) if total_time > 0 else 0 + ) + + return dict(breakdown) + + +def get_profile_stats(profile: cProfile.Profile): + """ + Analyzes a cProfile profile with Django-specific categorization. + + Returns structured data for markdown reporting including: + - Top 10 time-consuming functions + - Category breakdown (serializer, orm, view, business_logic, framework, stdlib, other) + - Filtering of stdlib/other unless >5% of total time + """ + + # 2. Create pstats.Stats object + stats = pstats.Stats(profile) + total_time = stats.total_tt + + if total_time == 0: + return _empty_result() + + # 3. 
Extract and process all functions + all_functions = [] + for func_key, ( + prim_calls, + ncalls, + tottime, + cumtime, + callers, + ) in stats.stats.items(): + # Parse function key + filename = func_key[0] if len(func_key) > 0 else "" + line_number = func_key[1] if len(func_key) > 1 else 0 + func_name = func_key[2] if len(func_key) > 2 else "" + + # Categorize + category = _categorize_function(filename, func_name) + + # Build function data dict + func_data = { + "function_name": func_name, + "filename": _strip_dirs(filename), + "line_number": line_number, + "category": category, + "call_count": ncalls, + "total_time": tottime, + "cumulative_time": cumtime, + "time_percent": (tottime / total_time * 100) if total_time > 0 else 0, + "per_call_time": tottime / ncalls if ncalls > 0 else 0, + } + + # Apply filtering + if _should_include(func_data, total_time, category): + all_functions.append(func_data) + + # 4. Select top 10 + top_functions = sorted(all_functions, key=lambda x: x["total_time"], reverse=True)[ + :10 + ] + + # 5. Calculate category breakdown + category_breakdown = _aggregate_by_category(all_functions, total_time) + + # 6. Return structured data + return { + "total_time": total_time, + "total_calls": stats.total_calls, + "top_functions": top_functions, + "category_breakdown": category_breakdown, + } diff --git a/pytest.ini b/pytest.ini index 1504e23be..8ac077305 100644 --- a/pytest.ini +++ b/pytest.ini @@ -10,6 +10,7 @@ markers = enable_signals: enables django signals to run. enable_rls: enables row-level-security in the database during testing. queryset: marks a database query count test. + perf: marks a performance test filterwarnings = error diff --git a/tox.ini b/tox.ini index 2ec2514e6..6a58fc9a5 100644 --- a/tox.ini +++ b/tox.ini @@ -15,6 +15,9 @@ runner = uv-venv-runner # Exists as a workaround to uv-venv-lock-runner running sync with --lock while we need --frozen. 
# Set "commands_pre = " to prevent the sync (useful to environments that only require a small selection of packages) commands_pre = uv sync --frozen --active +[testenv:perf-tests] +commands = + pytest --no-cov -m "perf" {posargs} [testenv:queryset-tests] commands = @@ -82,4 +85,4 @@ commands = uvx {[ruff]ruff_version} check --fix --diff --select I . [testenv:ruff-format] commands_pre = -commands = uvx {[ruff]ruff_version} format --check . \ No newline at end of file +commands = uvx {[ruff]ruff_version} format --check .