diff --git a/client/src/components/column.tsx b/client/src/components/column.tsx index 059b607f0..58abb4945 100644 --- a/client/src/components/column.tsx +++ b/client/src/components/column.tsx @@ -391,9 +391,11 @@ export function NumberRangeColumn(props: NumberColumnProps(props: Omit, "id"> & { field: Field }) { const field = props.field; + const dataId = ("extra." + field.key) as keyof Obj & string; const commonProps = { ...props, id: ["extra", field.key], + dataId, title: field.name, sorter: false, transform: (value: unknown) => { @@ -404,6 +406,12 @@ export function CustomFieldColumn(props: Omit { + const typedFilters = typeFilters(props.tableState.filters); + const filteredValue = getFiltersForField(typedFilters, dataId); + return { filters: filterItems, filteredValue }; + }; + if (field.field_type === FieldType.integer) { return NumberColumn({ ...commonProps, @@ -439,8 +447,15 @@ export function CustomFieldColumn(props: Omit", value: "" }, + ]; return Column({ ...commonProps, + ...buildFilterProps(filterItems), + allowMultipleFilters: false, render: (rawValue) => { const value = commonProps.transform ? commonProps.transform(rawValue) : rawValue; let text; @@ -455,16 +470,28 @@ export function CustomFieldColumn(props: Omit ({ + text: choice, + value: '"' + JSON.stringify(choice) + '"', + })); + filterItems.push({ text: "", value: "" }); return Column({ ...commonProps, + ...buildFilterProps(filterItems), render: (rawValue) => { const value = commonProps.transform ? commonProps.transform(rawValue) : rawValue; return ; }, }); } else if (field.field_type === FieldType.choice && field.multi_choice) { + const filterItems: ColumnFilterItem[] = (field.choices ?? []).map((choice) => ({ + text: choice, + value: choice, + })); + filterItems.push({ text: "", value: "" }); return Column({ ...commonProps, + ...buildFilterProps(filterItems), render: (rawValue) => { const value = commonProps.transform ? commonProps.transform(rawValue) : rawValue; return ; diff --git a/spoolman/api/v1/filament.py b/spoolman/api/v1/filament.py index 3e3f859af..49f945e96 100644 --- a/spoolman/api/v1/filament.py +++ b/spoolman/api/v1/filament.py @@ -4,7 +4,7 @@ import logging from typing import Annotated -from fastapi import APIRouter, Depends, Query, WebSocket, WebSocketDisconnect +from fastapi import APIRouter, Depends, Query, Request, WebSocket, WebSocketDisconnect from fastapi.encoders import jsonable_encoder from fastapi.responses import JSONResponse from pydantic import BaseModel, Field, field_validator, model_validator @@ -201,6 +201,7 @@ def prevent_none(cls: type["FilamentUpdateParameters"], v: float | None) -> floa ) async def find( *, + request: Request, db: Annotated[AsyncSession, Depends(get_db_session)], vendor_name_old: Annotated[ str | None, @@ -332,6 +333,12 @@ async def find( else: vendor_ids = None + extra_filters = { + key.removeprefix("extra."): value + for key, value in request.query_params.items() + if key.startswith("extra.") and key != "extra." + } + if color_hex is not None: matched_filaments = await filament.find_by_color( db=db, @@ -351,6 +358,7 @@ async def find( material=material, article_number=article_number, external_id=external_id, + extra=extra_filters or None, sort_by=sort_by, limit=limit, offset=offset, diff --git a/spoolman/api/v1/spool.py b/spoolman/api/v1/spool.py index 8f667e3da..eb8f3663d 100644 --- a/spoolman/api/v1/spool.py +++ b/spoolman/api/v1/spool.py @@ -5,7 +5,7 @@ from datetime import datetime from typing import Annotated -from fastapi import APIRouter, Depends, Query, WebSocket, WebSocketDisconnect +from fastapi import APIRouter, Depends, Query, Request, WebSocket, WebSocketDisconnect from fastapi.encoders import jsonable_encoder from fastapi.responses import JSONResponse from pydantic import BaseModel, Field, field_validator @@ -127,6 +127,7 @@ class SpoolMeasureParameters(BaseModel): ) async def find( *, + request: Request, db: Annotated[AsyncSession, Depends(get_db_session)], filament_name_old: Annotated[ str | None, @@ -285,6 +286,12 @@ async def find( else: filament_vendor_ids = None + extra_filters = { + key.removeprefix("extra."): value + for key, value in request.query_params.items() + if key.startswith("extra.") and key != "extra." + } + db_items, total_count = await spool.find( db=db, filament_name=filament_name if filament_name is not None else filament_name_old, @@ -295,6 +302,7 @@ async def find( location=location, lot_nr=lot_nr, allow_archived=allow_archived, + extra=extra_filters or None, sort_by=sort_by, limit=limit, offset=offset, diff --git a/spoolman/api/v1/vendor.py b/spoolman/api/v1/vendor.py index 9216fba30..94845aada 100644 --- a/spoolman/api/v1/vendor.py +++ b/spoolman/api/v1/vendor.py @@ -3,7 +3,7 @@ import asyncio from typing import Annotated -from fastapi import APIRouter, Depends, Query, WebSocket, WebSocketDisconnect +from fastapi import APIRouter, Depends, Query, Request, WebSocket, WebSocketDisconnect from fastapi.encoders import jsonable_encoder from fastapi.responses import JSONResponse from pydantic import BaseModel, Field, field_validator @@ -79,6 +79,7 @@ def prevent_none(cls: type["VendorUpdateParameters"], v: str | None) -> str | No }, ) async def find( + request: Request, db: Annotated[AsyncSession, Depends(get_db_session)], name: Annotated[ str | None, @@ -124,10 +125,17 @@ async def find( field, direction = sort_item.split(":") sort_by[field] = SortOrder[direction.upper()] + extra_filters = { + key.removeprefix("extra."): value + for key, value in request.query_params.items() + if key.startswith("extra.") and key != "extra." + } + db_items, total_count = await vendor.find( db=db, name=name, external_id=external_id, + extra=extra_filters or None, sort_by=sort_by, limit=limit, offset=offset, diff --git a/spoolman/database/filament.py b/spoolman/database/filament.py index e2d742758..80fff7a1f 100644 --- a/spoolman/database/filament.py +++ b/spoolman/database/filament.py @@ -14,6 +14,7 @@ from spoolman.database import models, vendor from spoolman.database.utils import ( SortOrder, + add_where_clause_extra_field, add_where_clause_int_in, add_where_clause_int_opt, add_where_clause_str, @@ -102,6 +103,7 @@ async def find( material: str | None = None, article_number: str | None = None, external_id: str | None = None, + extra: dict[str, str] | None = None, sort_by: dict[str, SortOrder] | None = None, limit: int | None = None, offset: int = 0, @@ -127,6 +129,11 @@ async def find( stmt = add_where_clause_str_opt(stmt, models.Filament.article_number, article_number) stmt = add_where_clause_str_opt(stmt, models.Filament.external_id, external_id) + if extra: + stmt = add_where_clause_extra_field( + stmt, models.FilamentField, models.FilamentField.filament_id, models.Filament.id, extra, + ) + total_count = None if limit is not None: diff --git a/spoolman/database/spool.py b/spoolman/database/spool.py index 5c190ce65..629d3aac7 100644 --- a/spoolman/database/spool.py +++ b/spoolman/database/spool.py @@ -15,6 +15,7 @@ from spoolman.database import filament, models from spoolman.database.utils import ( SortOrder, + add_where_clause_extra_field, add_where_clause_int, add_where_clause_int_opt, add_where_clause_str, @@ -122,6 +123,7 @@ async def find( # noqa: C901, PLR0912 location: str | None = None, lot_nr: str | None = None, allow_archived: bool = False, + extra: dict[str, str] | None = None, sort_by: dict[str, SortOrder] | None = None, limit: int | None = None, offset: int = 0, @@ -148,6 +150,11 @@ async def find( # noqa: C901, PLR0912 stmt = add_where_clause_str_opt(stmt, models.Spool.location, location) stmt = add_where_clause_str_opt(stmt, models.Spool.lot_nr, lot_nr) + if extra: + stmt = add_where_clause_extra_field( + stmt, models.SpoolField, models.SpoolField.spool_id, models.Spool.id, extra, + ) + if not allow_archived: # Since the archived field is nullable, and default is false, we need to check for both false or null stmt = stmt.where( diff --git a/spoolman/database/utils.py b/spoolman/database/utils.py index 2d8776c00..a02f7cdb5 100644 --- a/spoolman/database/utils.py +++ b/spoolman/database/utils.py @@ -129,3 +129,37 @@ def add_where_clause_int_in( if value is not None: stmt = stmt.where(field.in_(value)) return stmt + + +def add_where_clause_extra_field( + stmt: Select, + field_model: type[models.Base], + field_fk_column: attributes.InstrumentedAttribute, + entity_pk_column: attributes.InstrumentedAttribute, + extra_filters: dict[str, str], +) -> Select: + """Add where clauses to filter by extra field key/value pairs using EXISTS subqueries.""" + for field_key, search_value in extra_filters.items(): + value_conditions = [] + has_empty = False + + for value_part in search_value.split(","): + if len(value_part) == 0: + has_empty = True + elif value_part[0] == '"' and value_part[-1] == '"': + value_conditions.append(field_model.value == value_part[1:-1]) + else: + value_conditions.append(field_model.value.ilike(f"%{value_part}%")) + + base_where = sqlalchemy.and_(field_fk_column == entity_pk_column, field_model.key == field_key) + + conditions = [] + if value_conditions: + conditions.append(sqlalchemy.exists().where(sqlalchemy.and_(base_where, sqlalchemy.or_(*value_conditions)))) + if has_empty: + conditions.append(~sqlalchemy.exists().where(base_where)) + + if conditions: + stmt = stmt.where(sqlalchemy.or_(*conditions)) + + return stmt diff --git a/spoolman/database/vendor.py b/spoolman/database/vendor.py index f2e83018e..ce37a1918 100644 --- a/spoolman/database/vendor.py +++ b/spoolman/database/vendor.py @@ -9,7 +9,12 @@ from spoolman.api.v1.models import EventType, Vendor, VendorEvent from spoolman.database import models -from spoolman.database.utils import SortOrder, add_where_clause_str, add_where_clause_str_opt +from spoolman.database.utils import ( + SortOrder, + add_where_clause_extra_field, + add_where_clause_str, + add_where_clause_str_opt, +) from spoolman.exceptions import ItemNotFoundError from spoolman.ws import websocket_manager @@ -53,6 +58,7 @@ async def find( db: AsyncSession, name: str | None = None, external_id: str | None = None, + extra: dict[str, str] | None = None, sort_by: dict[str, SortOrder] | None = None, limit: int | None = None, offset: int = 0, @@ -66,6 +72,11 @@ async def find( stmt = add_where_clause_str(stmt, models.Vendor.name, name) stmt = add_where_clause_str_opt(stmt, models.Vendor.external_id, external_id) + if extra: + stmt = add_where_clause_extra_field( + stmt, models.VendorField, models.VendorField.vendor_id, models.Vendor.id, extra, + ) + total_count = None if limit is not None: diff --git a/tests_integration/tests/filament/test_find.py b/tests_integration/tests/filament/test_find.py index 44299aef4..b9106a99b 100644 --- a/tests_integration/tests/filament/test_find.py +++ b/tests_integration/tests/filament/test_find.py @@ -1,5 +1,6 @@ """Integration tests for the Filament API endpoint.""" +import json from collections.abc import Iterable from dataclasses import dataclass from typing import Any @@ -507,3 +508,74 @@ def test_find_filaments_by_similar_color_100(filaments: Fixture): filaments_result, [filaments.filaments[0], filaments.filaments[1], filaments.filaments[2]], ) + + +@dataclass +class ExtraFieldFixture: + filaments: list[dict[str, Any]] + + +@pytest.fixture(scope="module") +def filaments_with_extra() -> Iterable[ExtraFieldFixture]: + """Add filaments with extra fields to the database.""" + # Create extra field definition + result = httpx.post( + f"{URL}/api/v1/field/filament/tag", + json={"name": "Tag", "field_type": "text", "order": 0}, + ) + result.raise_for_status() + + result = httpx.post( + f"{URL}/api/v1/filament", + json={ + "name": "ExtraFilament1", + "density": 1.25, + "diameter": 1.75, + "extra": {"tag": json.dumps("production")}, + }, + ) + result.raise_for_status() + filament_1 = result.json() + + result = httpx.post( + f"{URL}/api/v1/filament", + json={ + "name": "ExtraFilament2", + "density": 1.25, + "diameter": 1.75, + }, + ) + result.raise_for_status() + filament_2 = result.json() + + yield ExtraFieldFixture(filaments=[filament_1, filament_2]) + + httpx.delete(f"{URL}/api/v1/filament/{filament_1['id']}").raise_for_status() + httpx.delete(f"{URL}/api/v1/filament/{filament_2['id']}").raise_for_status() + httpx.delete(f"{URL}/api/v1/field/filament/tag").raise_for_status() + + +def test_find_filaments_by_extra_field(filaments_with_extra: ExtraFieldFixture): + """Test filtering filaments by extra field value.""" + result = httpx.get( + f"{URL}/api/v1/filament", + params={"extra.tag": "production", "name": "ExtraFilament"}, + ) + result.raise_for_status() + + filaments_result = result.json() + assert len(filaments_result) == 1 + assert filaments_result[0]["id"] == filaments_with_extra.filaments[0]["id"] + + +def test_find_filaments_by_extra_field_empty(filaments_with_extra: ExtraFieldFixture): + """Test filtering filaments that do not have the extra field set.""" + result = httpx.get( + f"{URL}/api/v1/filament", + params={"extra.tag": "", "name": "ExtraFilament"}, + ) + result.raise_for_status() + + filaments_result = result.json() + assert len(filaments_result) == 1 + assert filaments_result[0]["id"] == filaments_with_extra.filaments[1]["id"] diff --git a/tests_integration/tests/spool/test_find.py b/tests_integration/tests/spool/test_find.py index 51a275867..4f9406f39 100644 --- a/tests_integration/tests/spool/test_find.py +++ b/tests_integration/tests/spool/test_find.py @@ -1,5 +1,6 @@ """Integration tests for the Spool API endpoint.""" +import json from collections.abc import Iterable from dataclasses import dataclass from typing import Any @@ -485,3 +486,72 @@ def test_find_spools_by_empty_lot_nr(spools: Fixture): # Verify spools_result = result.json() assert_lists_compatible(spools_result, (spools.spools[3], spools.spools[4])) + + +@dataclass +class ExtraFieldFixture: + spools: list[dict[str, Any]] + + +@pytest.fixture(scope="module") +def spools_with_extra(random_filament_mod: dict[str, Any]) -> Iterable[ExtraFieldFixture]: + """Add spools with extra fields to the database.""" + # Create extra field definition + result = httpx.post( + f"{URL}/api/v1/field/spool/tag", + json={"name": "Tag", "field_type": "text", "order": 0}, + ) + result.raise_for_status() + + result = httpx.post( + f"{URL}/api/v1/spool", + json={ + "filament_id": random_filament_mod["id"], + "location": "ExtraSpoolLoc", + "extra": {"tag": json.dumps("production")}, + }, + ) + result.raise_for_status() + spool_1 = result.json() + + result = httpx.post( + f"{URL}/api/v1/spool", + json={ + "filament_id": random_filament_mod["id"], + "location": "ExtraSpoolLoc", + }, + ) + result.raise_for_status() + spool_2 = result.json() + + yield ExtraFieldFixture(spools=[spool_1, spool_2]) + + httpx.delete(f"{URL}/api/v1/spool/{spool_1['id']}").raise_for_status() + httpx.delete(f"{URL}/api/v1/spool/{spool_2['id']}").raise_for_status() + httpx.delete(f"{URL}/api/v1/field/spool/tag").raise_for_status() + + +def test_find_spools_by_extra_field(spools_with_extra: ExtraFieldFixture): + """Test filtering spools by extra field value.""" + result = httpx.get( + f"{URL}/api/v1/spool", + params={"extra.tag": "production", "location": "ExtraSpoolLoc", "allow_archived": True}, + ) + result.raise_for_status() + + spools_result = result.json() + assert len(spools_result) == 1 + assert spools_result[0]["id"] == spools_with_extra.spools[0]["id"] + + +def test_find_spools_by_extra_field_empty(spools_with_extra: ExtraFieldFixture): + """Test filtering spools that do not have the extra field set.""" + result = httpx.get( + f"{URL}/api/v1/spool", + params={"extra.tag": "", "location": "ExtraSpoolLoc", "allow_archived": True}, + ) + result.raise_for_status() + + spools_result = result.json() + assert len(spools_result) == 1 + assert spools_result[0]["id"] == spools_with_extra.spools[1]["id"] diff --git a/tests_integration/tests/vendor/test_find.py b/tests_integration/tests/vendor/test_find.py index 2649a6d5b..99f57be98 100644 --- a/tests_integration/tests/vendor/test_find.py +++ b/tests_integration/tests/vendor/test_find.py @@ -1,5 +1,6 @@ """Integration tests for the Vendor API endpoint.""" +import json from collections.abc import Iterable from dataclasses import dataclass from typing import Any @@ -211,3 +212,103 @@ def test_find_vendors_by_empty_external_id(vendors: Fixture): # Verify vendors_result = result.json() assert_lists_compatible(vendors_result, [vendors.vendors[0], vendors.vendors[2]]) + + +@dataclass +class ExtraFieldFixture: + vendors: list[dict[str, Any]] + + +@pytest.fixture(scope="module") +def vendors_with_extra() -> Iterable[ExtraFieldFixture]: + """Add vendors with extra fields to the database.""" + # Create extra field definition + result = httpx.post( + f"{URL}/api/v1/field/vendor/tag", + json={"name": "Tag", "field_type": "text", "order": 0}, + ) + result.raise_for_status() + + # Create vendors with extra field values + result = httpx.post( + f"{URL}/api/v1/vendor", + json={"name": "ExtraVendor1", "extra": {"tag": json.dumps("production")}}, + ) + result.raise_for_status() + vendor_1 = result.json() + + result = httpx.post( + f"{URL}/api/v1/vendor", + json={"name": "ExtraVendor2", "extra": {"tag": json.dumps("testing")}}, + ) + result.raise_for_status() + vendor_2 = result.json() + + result = httpx.post( + f"{URL}/api/v1/vendor", + json={"name": "ExtraVendor3"}, + ) + result.raise_for_status() + vendor_3 = result.json() + + yield ExtraFieldFixture(vendors=[vendor_1, vendor_2, vendor_3]) + + httpx.delete(f"{URL}/api/v1/vendor/{vendor_1['id']}").raise_for_status() + httpx.delete(f"{URL}/api/v1/vendor/{vendor_2['id']}").raise_for_status() + httpx.delete(f"{URL}/api/v1/vendor/{vendor_3['id']}").raise_for_status() + httpx.delete(f"{URL}/api/v1/field/vendor/tag").raise_for_status() + + +def test_find_vendors_by_extra_field(vendors_with_extra: ExtraFieldFixture): + """Test filtering vendors by extra field value (partial match).""" + result = httpx.get( + f"{URL}/api/v1/vendor", + params={"extra.tag": "production", "name": "ExtraVendor"}, + ) + result.raise_for_status() + + vendors_result = result.json() + assert len(vendors_result) == 1 + assert vendors_result[0]["id"] == vendors_with_extra.vendors[0]["id"] + + +def test_find_vendors_by_extra_field_exact(vendors_with_extra: ExtraFieldFixture): + """Test filtering vendors by extra field value (exact match). + + Extra field values are JSON-encoded, so text "production" is stored as '"production"'. + To exact-match, we need to include the JSON quotes in the search term. + """ + result = httpx.get( + f"{URL}/api/v1/vendor", + params={"extra.tag": '""production""', "name": "ExtraVendor"}, + ) + result.raise_for_status() + + vendors_result = result.json() + assert len(vendors_result) == 1 + assert vendors_result[0]["id"] == vendors_with_extra.vendors[0]["id"] + + +def test_find_vendors_by_extra_field_no_match(vendors_with_extra: ExtraFieldFixture): # noqa: ARG001 + """Test filtering vendors by extra field value with no match.""" + result = httpx.get( + f"{URL}/api/v1/vendor", + params={"extra.tag": "nonexistent", "name": "ExtraVendor"}, + ) + result.raise_for_status() + + vendors_result = result.json() + assert len(vendors_result) == 0 + + +def test_find_vendors_by_extra_field_empty(vendors_with_extra: ExtraFieldFixture): + """Test filtering vendors that do not have the extra field set.""" + result = httpx.get( + f"{URL}/api/v1/vendor", + params={"extra.tag": "", "name": "ExtraVendor"}, + ) + result.raise_for_status() + + vendors_result = result.json() + assert len(vendors_result) == 1 + assert vendors_result[0]["id"] == vendors_with_extra.vendors[2]["id"]