Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 27 additions & 0 deletions client/src/components/column.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -391,9 +391,11 @@ export function NumberRangeColumn<Obj extends Entity>(props: NumberColumnProps<O

export function CustomFieldColumn<Obj extends Entity>(props: Omit<BaseColumnProps<Obj>, "id"> & { field: Field }) {
const field = props.field;
const dataId = ("extra." + field.key) as keyof Obj & string;
const commonProps = {
...props,
id: ["extra", field.key],
dataId,
title: field.name,
sorter: false,
transform: (value: unknown) => {
Expand All @@ -404,6 +406,12 @@ export function CustomFieldColumn<Obj extends Entity>(props: Omit<BaseColumnProp
},
};

const buildFilterProps = (filterItems: ColumnFilterItem[]) => {
const typedFilters = typeFilters<Obj>(props.tableState.filters);
const filteredValue = getFiltersForField(typedFilters, dataId);
return { filters: filterItems, filteredValue };
};

if (field.field_type === FieldType.integer) {
return NumberColumn({
...commonProps,
Expand Down Expand Up @@ -439,8 +447,15 @@ export function CustomFieldColumn<Obj extends Entity>(props: Omit<BaseColumnProp
...commonProps,
});
} else if (field.field_type === FieldType.boolean) {
const filterItems: ColumnFilterItem[] = [
{ text: props.t("yes"), value: '"true"' },
{ text: props.t("no"), value: '"false"' },
{ text: "<empty>", value: "<empty>" },
];
return Column({
...commonProps,
...buildFilterProps(filterItems),
allowMultipleFilters: false,
render: (rawValue) => {
const value = commonProps.transform ? commonProps.transform(rawValue) : rawValue;
let text;
Expand All @@ -455,16 +470,28 @@ export function CustomFieldColumn<Obj extends Entity>(props: Omit<BaseColumnProp
},
});
} else if (field.field_type === FieldType.choice && !field.multi_choice) {
const filterItems: ColumnFilterItem[] = (field.choices ?? []).map((choice) => ({
text: choice,
value: '"' + JSON.stringify(choice) + '"',
}));
filterItems.push({ text: "<empty>", value: "<empty>" });
return Column({
...commonProps,
...buildFilterProps(filterItems),
render: (rawValue) => {
const value = commonProps.transform ? commonProps.transform(rawValue) : rawValue;
return <TextField value={value} />;
},
});
} else if (field.field_type === FieldType.choice && field.multi_choice) {
const filterItems: ColumnFilterItem[] = (field.choices ?? []).map((choice) => ({
text: choice,
value: choice,
}));
filterItems.push({ text: "<empty>", value: "<empty>" });
return Column({
...commonProps,
...buildFilterProps(filterItems),
render: (rawValue) => {
const value = commonProps.transform ? commonProps.transform(rawValue) : rawValue;
return <TextField value={(value as string[] | undefined)?.join(", ")} />;
Expand Down
10 changes: 9 additions & 1 deletion spoolman/api/v1/filament.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import logging
from typing import Annotated

from fastapi import APIRouter, Depends, Query, WebSocket, WebSocketDisconnect
from fastapi import APIRouter, Depends, Query, Request, WebSocket, WebSocketDisconnect
from fastapi.encoders import jsonable_encoder
from fastapi.responses import JSONResponse
from pydantic import BaseModel, Field, field_validator, model_validator
Expand Down Expand Up @@ -201,6 +201,7 @@ def prevent_none(cls: type["FilamentUpdateParameters"], v: float | None) -> floa
)
async def find(
*,
request: Request,
db: Annotated[AsyncSession, Depends(get_db_session)],
vendor_name_old: Annotated[
str | None,
Expand Down Expand Up @@ -332,6 +333,12 @@ async def find(
else:
vendor_ids = None

extra_filters = {
key.removeprefix("extra."): value
for key, value in request.query_params.items()
if key.startswith("extra.") and key != "extra."
}

if color_hex is not None:
matched_filaments = await filament.find_by_color(
db=db,
Expand All @@ -351,6 +358,7 @@ async def find(
material=material,
article_number=article_number,
external_id=external_id,
extra=extra_filters or None,
sort_by=sort_by,
limit=limit,
offset=offset,
Expand Down
10 changes: 9 additions & 1 deletion spoolman/api/v1/spool.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from datetime import datetime
from typing import Annotated

from fastapi import APIRouter, Depends, Query, WebSocket, WebSocketDisconnect
from fastapi import APIRouter, Depends, Query, Request, WebSocket, WebSocketDisconnect
from fastapi.encoders import jsonable_encoder
from fastapi.responses import JSONResponse
from pydantic import BaseModel, Field, field_validator
Expand Down Expand Up @@ -127,6 +127,7 @@ class SpoolMeasureParameters(BaseModel):
)
async def find(
*,
request: Request,
db: Annotated[AsyncSession, Depends(get_db_session)],
filament_name_old: Annotated[
str | None,
Expand Down Expand Up @@ -285,6 +286,12 @@ async def find(
else:
filament_vendor_ids = None

extra_filters = {
key.removeprefix("extra."): value
for key, value in request.query_params.items()
if key.startswith("extra.") and key != "extra."
}

db_items, total_count = await spool.find(
db=db,
filament_name=filament_name if filament_name is not None else filament_name_old,
Expand All @@ -295,6 +302,7 @@ async def find(
location=location,
lot_nr=lot_nr,
allow_archived=allow_archived,
extra=extra_filters or None,
sort_by=sort_by,
limit=limit,
offset=offset,
Expand Down
10 changes: 9 additions & 1 deletion spoolman/api/v1/vendor.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import asyncio
from typing import Annotated

from fastapi import APIRouter, Depends, Query, WebSocket, WebSocketDisconnect
from fastapi import APIRouter, Depends, Query, Request, WebSocket, WebSocketDisconnect
from fastapi.encoders import jsonable_encoder
from fastapi.responses import JSONResponse
from pydantic import BaseModel, Field, field_validator
Expand Down Expand Up @@ -79,6 +79,7 @@ def prevent_none(cls: type["VendorUpdateParameters"], v: str | None) -> str | No
},
)
async def find(
request: Request,
db: Annotated[AsyncSession, Depends(get_db_session)],
name: Annotated[
str | None,
Expand Down Expand Up @@ -124,10 +125,17 @@ async def find(
field, direction = sort_item.split(":")
sort_by[field] = SortOrder[direction.upper()]

extra_filters = {
key.removeprefix("extra."): value
for key, value in request.query_params.items()
if key.startswith("extra.") and key != "extra."
}

db_items, total_count = await vendor.find(
db=db,
name=name,
external_id=external_id,
extra=extra_filters or None,
sort_by=sort_by,
limit=limit,
offset=offset,
Expand Down
7 changes: 7 additions & 0 deletions spoolman/database/filament.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from spoolman.database import models, vendor
from spoolman.database.utils import (
SortOrder,
add_where_clause_extra_field,
add_where_clause_int_in,
add_where_clause_int_opt,
add_where_clause_str,
Expand Down Expand Up @@ -102,6 +103,7 @@ async def find(
material: str | None = None,
article_number: str | None = None,
external_id: str | None = None,
extra: dict[str, str] | None = None,
sort_by: dict[str, SortOrder] | None = None,
limit: int | None = None,
offset: int = 0,
Expand All @@ -127,6 +129,11 @@ async def find(
stmt = add_where_clause_str_opt(stmt, models.Filament.article_number, article_number)
stmt = add_where_clause_str_opt(stmt, models.Filament.external_id, external_id)

if extra:
stmt = add_where_clause_extra_field(
stmt, models.FilamentField, models.FilamentField.filament_id, models.Filament.id, extra,
)

total_count = None

if limit is not None:
Expand Down
7 changes: 7 additions & 0 deletions spoolman/database/spool.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from spoolman.database import filament, models
from spoolman.database.utils import (
SortOrder,
add_where_clause_extra_field,
add_where_clause_int,
add_where_clause_int_opt,
add_where_clause_str,
Expand Down Expand Up @@ -122,6 +123,7 @@ async def find( # noqa: C901, PLR0912
location: str | None = None,
lot_nr: str | None = None,
allow_archived: bool = False,
extra: dict[str, str] | None = None,
sort_by: dict[str, SortOrder] | None = None,
limit: int | None = None,
offset: int = 0,
Expand All @@ -148,6 +150,11 @@ async def find( # noqa: C901, PLR0912
stmt = add_where_clause_str_opt(stmt, models.Spool.location, location)
stmt = add_where_clause_str_opt(stmt, models.Spool.lot_nr, lot_nr)

if extra:
stmt = add_where_clause_extra_field(
stmt, models.SpoolField, models.SpoolField.spool_id, models.Spool.id, extra,
)

if not allow_archived:
# Since the archived field is nullable, and default is false, we need to check for both false or null
stmt = stmt.where(
Expand Down
34 changes: 34 additions & 0 deletions spoolman/database/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,3 +129,37 @@ def add_where_clause_int_in(
if value is not None:
stmt = stmt.where(field.in_(value))
return stmt


def add_where_clause_extra_field(
stmt: Select,
field_model: type[models.Base],
field_fk_column: attributes.InstrumentedAttribute,
entity_pk_column: attributes.InstrumentedAttribute,
extra_filters: dict[str, str],
) -> Select:
"""Add where clauses to filter by extra field key/value pairs using EXISTS subqueries."""
for field_key, search_value in extra_filters.items():
value_conditions = []
has_empty = False

for value_part in search_value.split(","):
if len(value_part) == 0:
has_empty = True
elif value_part[0] == '"' and value_part[-1] == '"':
value_conditions.append(field_model.value == value_part[1:-1])
else:
value_conditions.append(field_model.value.ilike(f"%{value_part}%"))

base_where = sqlalchemy.and_(field_fk_column == entity_pk_column, field_model.key == field_key)

conditions = []
if value_conditions:
conditions.append(sqlalchemy.exists().where(sqlalchemy.and_(base_where, sqlalchemy.or_(*value_conditions))))
if has_empty:
conditions.append(~sqlalchemy.exists().where(base_where))

if conditions:
stmt = stmt.where(sqlalchemy.or_(*conditions))

return stmt
13 changes: 12 additions & 1 deletion spoolman/database/vendor.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,12 @@

from spoolman.api.v1.models import EventType, Vendor, VendorEvent
from spoolman.database import models
from spoolman.database.utils import SortOrder, add_where_clause_str, add_where_clause_str_opt
from spoolman.database.utils import (
SortOrder,
add_where_clause_extra_field,
add_where_clause_str,
add_where_clause_str_opt,
)
from spoolman.exceptions import ItemNotFoundError
from spoolman.ws import websocket_manager

Expand Down Expand Up @@ -53,6 +58,7 @@ async def find(
db: AsyncSession,
name: str | None = None,
external_id: str | None = None,
extra: dict[str, str] | None = None,
sort_by: dict[str, SortOrder] | None = None,
limit: int | None = None,
offset: int = 0,
Expand All @@ -66,6 +72,11 @@ async def find(
stmt = add_where_clause_str(stmt, models.Vendor.name, name)
stmt = add_where_clause_str_opt(stmt, models.Vendor.external_id, external_id)

if extra:
stmt = add_where_clause_extra_field(
stmt, models.VendorField, models.VendorField.vendor_id, models.Vendor.id, extra,
)

total_count = None

if limit is not None:
Expand Down
72 changes: 72 additions & 0 deletions tests_integration/tests/filament/test_find.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Integration tests for the Filament API endpoint."""

import json
from collections.abc import Iterable
from dataclasses import dataclass
from typing import Any
Expand Down Expand Up @@ -507,3 +508,74 @@ def test_find_filaments_by_similar_color_100(filaments: Fixture):
filaments_result,
[filaments.filaments[0], filaments.filaments[1], filaments.filaments[2]],
)


@dataclass
class ExtraFieldFixture:
filaments: list[dict[str, Any]]


@pytest.fixture(scope="module")
def filaments_with_extra() -> Iterable[ExtraFieldFixture]:
"""Add filaments with extra fields to the database."""
# Create extra field definition
result = httpx.post(
f"{URL}/api/v1/field/filament/tag",
json={"name": "Tag", "field_type": "text", "order": 0},
)
result.raise_for_status()

result = httpx.post(
f"{URL}/api/v1/filament",
json={
"name": "ExtraFilament1",
"density": 1.25,
"diameter": 1.75,
"extra": {"tag": json.dumps("production")},
},
)
result.raise_for_status()
filament_1 = result.json()

result = httpx.post(
f"{URL}/api/v1/filament",
json={
"name": "ExtraFilament2",
"density": 1.25,
"diameter": 1.75,
},
)
result.raise_for_status()
filament_2 = result.json()

yield ExtraFieldFixture(filaments=[filament_1, filament_2])

httpx.delete(f"{URL}/api/v1/filament/{filament_1['id']}").raise_for_status()
httpx.delete(f"{URL}/api/v1/filament/{filament_2['id']}").raise_for_status()
httpx.delete(f"{URL}/api/v1/field/filament/tag").raise_for_status()


def test_find_filaments_by_extra_field(filaments_with_extra: ExtraFieldFixture):
"""Test filtering filaments by extra field value."""
result = httpx.get(
f"{URL}/api/v1/filament",
params={"extra.tag": "production", "name": "ExtraFilament"},
)
result.raise_for_status()

filaments_result = result.json()
assert len(filaments_result) == 1
assert filaments_result[0]["id"] == filaments_with_extra.filaments[0]["id"]


def test_find_filaments_by_extra_field_empty(filaments_with_extra: ExtraFieldFixture):
"""Test filtering filaments that do not have the extra field set."""
result = httpx.get(
f"{URL}/api/v1/filament",
params={"extra.tag": "", "name": "ExtraFilament"},
)
result.raise_for_status()

filaments_result = result.json()
assert len(filaments_result) == 1
assert filaments_result[0]["id"] == filaments_with_extra.filaments[1]["id"]
Loading