diff --git a/CHANGELOG.md b/CHANGELOG.md index 6d66ec1..7292c17 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,9 +5,12 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). +## [Unreleased] + ## [0.1.0-rev3] ### Added + - **Mining Performance Analysis Domain** (`edge_mining/domain/performance/`): - Value objects: `MiningReward`, `PoolWorkerStats`, `PoolStats`, `PayoutSchedule` in `value_objects.py` (renamed from misspelled `values_objects.py`) - Entity `MiningSession` for tracking aggregated mining activity (`entities.py`) @@ -61,12 +64,49 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - **Tests** — 107 new unit tests across tracker adapters, configuration service, REST router, and CLI commands - **Tests** — 15 new unit tests for `CachedRateLimitedTrackerBase` (cache hit/miss, TTL expiry, backoff progression, stale-while-error, retry-after handling, cache invalidation) plus 429-detection tests for the Ocean and Braiins adapters +- **Home Load — Phase 3: DecisionalContext Integration** + - Extended field resolver (`helpers.py`) with dict key lookup for `home_load.devices..*` paths and `None` guard for `Optional` intermediate fields + - Pre-computed window properties on `LoadEnergyConsumption`: `next_1h`, `next_2h`, `next_4h`, `last_1h`, `last_4h`, `last_24h` + - Example YAML rules: `home_load_start_rules.yaml` (3 rules), `home_load_stop_rules.yaml` (4 rules) + - Fixed existing rules from `home_load_forecast` → `home_load.total_forecast.next_2h.avg_power` + - 5 new unit tests for dict resolver + +- **Home Load — Phase 4: ML Forecast Providers (Statsmodels + XGBoost)** + - ML optional dependencies: `scikit-learn>=1.5.0`, `statsmodels>=0.14.0`, `xgboost>=2.0.0` in `[ml]` extras + - `EnergyLoadForecastProviderAdapter.STATSMODELS` and `.XGBOOST` 
enum values + - Config dataclasses: `EnergyLoadForecastProviderStatsmodelsConfig`, `EnergyLoadForecastProviderXGBoostConfig` + - Feature engineering utilities (`features.py`): `intervals_to_hourly_series()`, `fill_missing_hours()`, `build_calendar_features()`, `build_lag_features()`, `prepare_supervised_dataset()` + - `LoadConsumptionModel` entity with `model_bytes` (serialized pickle), MAE/RMSE metrics, `is_active` flag + - `LoadConsumptionModelRepository` port + three implementations: InMemory, SQLite, SQLAlchemy + - `load_consumption_models` database table with composite index on `(adapter_type, device_id, is_active)` + - Alembic migration `c3d4e5f6a7b8` for `load_consumption_models` table + - `StatsmodelsForecastProvider` (Holt-Winters exponential smoothing) with factory, lazy import + - `XGBoostForecastProvider` (gradient boosting with calendar + lag features) with factory, lazy import + - `LoadForecastModelTrainingService`: nightly batch training with holdout evaluation + best model promotion + - Pydantic schemas: `EnergyLoadForecastProviderStatsmodelsConfigSchema`, `EnergyLoadForecastProviderXGBoostConfigSchema` + - Scheduler cron job at 04:00 for nightly ML model training + +- **Home Load — API Completion** + - `ConfigurationServiceInterface`: 10 new abstract CRUD methods for `EnergyLoadForecastProvider` (5) and `EnergyLoadHistoryProvider` (5) + - `ConfigurationService`: implemented `add_`, `get_`, `list_`, `update_`, `remove_energy_load_forecast_provider` + - Completed 5 forecast provider REST endpoints (previously stubs returning 501/empty): + - `GET /energy-load-forecast-providers` — list all providers + - `POST /energy-load-forecast-providers` — create and persist provider + - `GET /energy-load-forecast-providers/{id}` — get provider by ID + - `PUT /energy-load-forecast-providers/{id}` — update provider with config deserialization + - `DELETE /energy-load-forecast-providers/{id}` — remove provider + ### Changed - `MiningPerformanceTrackerPort` methods 
are now `async`; the dummy tracker adapter has been adapted accordingly - `OptimizationService` now awaits `get_current_hashrate` calls to match the async port contract, and consolidates the three tracker calls behind a new private helper `_build_mining_performance_snapshot` that returns a single `MiningPerformanceSnapshot` - Replaced `DecisionalContext.tracker_current_hashrate: Optional[HashRate]` with `mining_performance: Optional[MiningPerformanceSnapshot]`; `DecisionalContextSchema` and the rule engine `OPERATOR_EXAMPLES[LTE]` example updated accordingly (new field path: `mining_performance.current_hashrate.value`) - Interactive CLI main menu: "Run all optimization units" shifted from option 8 to 9 to accommodate the new tracker menu at option 8 - Replaced per-module `_utc_now_timestamp()` helpers in `domain/performance/entities.py` and `domain/performance/value_objects.py` with the shared `utc_now_timestamp()` from `domain/common.py` +- `AdapterService`: new factory branches for STATSMODELS and XGBOOST with `model_repo` injection +- `PersistenceSettings`: added `load_consumption_model_repo` field +- `Services` dataclass: added `load_forecast_training_service` field +- `AutomationScheduler`: accepts optional `load_forecast_training_service`, schedules nightly training +- `bootstrap.py`: `LoadConsumptionModelRepository` wired in all three persistence branches (InMemory/SQLite/SQLAlchemy) ### Fixed - Replaced latent `default_factory=Timestamp(datetime.now())` bugs (which froze a single timestamp at class-definition time) with the proper callable `utc_now_timestamp`, producing a fresh timestamp per instance diff --git a/alembic.ini b/alembic.ini index 6c274a2..d7e9248 100644 --- a/alembic.ini +++ b/alembic.ini @@ -86,7 +86,7 @@ path_separator = os # database URL. This is consumed by the user-maintained env.py script only. # other means of configuring database URLs may be customized within the env.py # file. 
-sqlalchemy.url = sqlite:///./edgemining.db +sqlalchemy.url = sqlite:///data/db/edgemining.db [post_write_hooks] diff --git a/alembic/versions/4e55fe6113c7_initial_schema_with_all_tables.py b/alembic/versions/4e55fe6113c7_initial_schema_with_all_tables.py index e67f84f..c151755 100644 --- a/alembic/versions/4e55fe6113c7_initial_schema_with_all_tables.py +++ b/alembic/versions/4e55fe6113c7_initial_schema_with_all_tables.py @@ -98,12 +98,12 @@ def upgrade() -> None: ) op.create_index(op.f("ix_forecast_providers_id"), "forecast_providers", ["id"], unique=False) op.create_table( - "home_forecast_providers", + "energy_load_forecast_providers", sa.Column("id", sa.String(), nullable=False), sa.Column("name", sa.String(), nullable=False), sa.Column("adapter_type", sa.String(), nullable=False), sa.Column( - "config", edge_mining.adapters.domain.home_load.tables.HomeForecastProviderConfigType(), nullable=True + "config", edge_mining.adapters.domain.home_load.tables.EnergyLoadForecastProviderConfigType(), nullable=True ), sa.Column("external_service_id", sa.String(), nullable=True), sa.ForeignKeyConstraint( @@ -112,7 +112,9 @@ def upgrade() -> None: ), sa.PrimaryKeyConstraint("id"), ) - op.create_index(op.f("ix_home_forecast_providers_id"), "home_forecast_providers", ["id"], unique=False) + op.create_index( + op.f("ix_energy_load_forecast_providers_id"), "energy_load_forecast_providers", ["id"], unique=False + ) op.create_table( "miner_controllers", sa.Column("id", sa.String(), nullable=False), @@ -206,7 +208,6 @@ def upgrade() -> None: "target_miner_ids", edge_mining.adapters.domain.optimization_unit.tables.EntityIdListType(), nullable=False ), sa.Column("energy_source_id", sa.String(), nullable=True), - sa.Column("home_forecast_provider_id", sa.String(), nullable=True), sa.Column("performance_tracker_id", sa.String(), nullable=True), sa.Column( "notifier_ids", edge_mining.adapters.domain.optimization_unit.tables.EntityIdListType(), nullable=False @@ -215,10 +216,6 @@ def 
upgrade() -> None: ["energy_source_id"], ["energy_sources.id"], ), - sa.ForeignKeyConstraint( - ["home_forecast_provider_id"], - ["home_forecast_providers.id"], - ), sa.ForeignKeyConstraint( ["performance_tracker_id"], ["mining_performance_trackers.id"], @@ -226,12 +223,27 @@ def upgrade() -> None: sa.PrimaryKeyConstraint("id"), ) op.create_index(op.f("ix_optimization_units_id"), "optimization_units", ["id"], unique=False) + op.create_table( + "home_load_power_points", + sa.Column("device_id", sa.String(), nullable=False), + sa.Column("timestamp", sa.DateTime(timezone=True), nullable=False), + sa.Column("power", sa.Float(), nullable=False), + sa.PrimaryKeyConstraint("device_id", "timestamp"), + ) + op.create_index( + "ix_home_load_power_points_device_ts", + "home_load_power_points", + ["device_id", "timestamp"], + unique=False, + ) # ### end Alembic commands ### def downgrade() -> None: """Downgrade schema.""" # ### commands auto generated by Alembic - please adjust! ### + op.drop_index("ix_home_load_power_points_device_ts", table_name="home_load_power_points") + op.drop_table("home_load_power_points") op.drop_index(op.f("ix_optimization_units_id"), table_name="optimization_units") op.drop_table("optimization_units") op.drop_index(op.f("ix_miners_id"), table_name="miners") @@ -244,8 +256,8 @@ def downgrade() -> None: op.drop_table("mining_performance_trackers") op.drop_index(op.f("ix_miner_controllers_id"), table_name="miner_controllers") op.drop_table("miner_controllers") - op.drop_index(op.f("ix_home_forecast_providers_id"), table_name="home_forecast_providers") - op.drop_table("home_forecast_providers") + op.drop_index(op.f("ix_energy_load_forecast_providers_id"), table_name="energy_load_forecast_providers") + op.drop_table("energy_load_forecast_providers") op.drop_index(op.f("ix_forecast_providers_id"), table_name="forecast_providers") op.drop_table("forecast_providers") op.drop_index(op.f("ix_energy_monitors_id"), table_name="energy_monitors") diff --git 
a/alembic/versions/b2c3d4e5f6a7_add_energy_load_history_providers.py b/alembic/versions/b2c3d4e5f6a7_add_energy_load_history_providers.py new file mode 100644 index 0000000..a003289 --- /dev/null +++ b/alembic/versions/b2c3d4e5f6a7_add_energy_load_history_providers.py @@ -0,0 +1,56 @@ +"""Add energy_load_history_providers table + +Revision ID: b2c3d4e5f6a7 +Revises: a1b2c3d4e5f6 +Create Date: 2026-04-22 10:00:00.000000 + +""" + +from typing import Sequence, Union + +import sqlalchemy as sa + +import edge_mining.adapters.domain.home_load.tables +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = "b2c3d4e5f6a7" +down_revision: Union[str, Sequence[str], None] = "a1b2c3d4e5f6" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + """Add energy_load_history_providers table.""" + op.create_table( + "energy_load_history_providers", + sa.Column("id", sa.String(), nullable=False), + sa.Column("name", sa.String(), nullable=False), + sa.Column("adapter_type", sa.String(), nullable=False), + sa.Column( + "config", + edge_mining.adapters.domain.home_load.tables.EnergyLoadHistoryProviderConfigType(), + nullable=True, + ), + sa.Column("external_service_id", sa.String(), nullable=True), + sa.ForeignKeyConstraint( + ["external_service_id"], + ["external_services.id"], + ), + sa.PrimaryKeyConstraint("id"), + ) + op.create_index( + op.f("ix_energy_load_history_providers_id"), + "energy_load_history_providers", + ["id"], + unique=False, + ) + + +def downgrade() -> None: + """Remove energy_load_history_providers table.""" + op.drop_index( + op.f("ix_energy_load_history_providers_id"), + table_name="energy_load_history_providers", + ) + op.drop_table("energy_load_history_providers") diff --git a/alembic/versions/c3d4e5f6a7b8_add_load_consumption_models.py b/alembic/versions/c3d4e5f6a7b8_add_load_consumption_models.py new file mode 100644 index 0000000..614f4ca --- 
/dev/null +++ b/alembic/versions/c3d4e5f6a7b8_add_load_consumption_models.py @@ -0,0 +1,47 @@ +"""Add load_consumption_models table + +Revision ID: c3d4e5f6a7b8 +Revises: b2c3d4e5f6a7 +Create Date: 2026-04-22 14:00:00.000000 + +""" + +from typing import Sequence, Union + +import sqlalchemy as sa + +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = "c3d4e5f6a7b8" +down_revision: Union[str, None] = "b2c3d4e5f6a7" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + op.create_table( + "load_consumption_models", + sa.Column("id", sa.String(), nullable=False), + sa.Column("device_id", sa.String(), nullable=True), + sa.Column("adapter_type", sa.String(), nullable=False), + sa.Column("trained_at", sa.DateTime(timezone=True), nullable=True), + sa.Column("mae", sa.Float(), nullable=True), + sa.Column("rmse", sa.Float(), nullable=True), + sa.Column("samples_used", sa.Integer(), nullable=False, server_default="0"), + sa.Column("is_active", sa.Boolean(), nullable=False, server_default="0"), + sa.Column("model_bytes", sa.LargeBinary(), nullable=True), + sa.PrimaryKeyConstraint("id"), + ) + op.create_index("ix_load_consumption_models_id", "load_consumption_models", ["id"]) + op.create_index( + "ix_load_consumption_models_active", + "load_consumption_models", + ["adapter_type", "device_id", "is_active"], + ) + + +def downgrade() -> None: + op.drop_index("ix_load_consumption_models_active", table_name="load_consumption_models") + op.drop_index("ix_load_consumption_models_id", table_name="load_consumption_models") + op.drop_table("load_consumption_models") diff --git a/alembic/versions/d4e5f6a7b8c9_add_tuning_and_backtesting_columns.py b/alembic/versions/d4e5f6a7b8c9_add_tuning_and_backtesting_columns.py new file mode 100644 index 0000000..6f381c4 --- /dev/null +++ b/alembic/versions/d4e5f6a7b8c9_add_tuning_and_backtesting_columns.py @@ -0,0 +1,35 @@ +"""Add tuning 
and backtesting columns to load_consumption_models + +Revision ID: d4e5f6a7b8c9 +Revises: c3d4e5f6a7b8 +Create Date: 2026-04-25 10:00:00.000000 + +""" + +from typing import Sequence, Union + +import sqlalchemy as sa + +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = "d4e5f6a7b8c9" +down_revision: Union[str, None] = "c3d4e5f6a7b8" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + with op.batch_alter_table("load_consumption_models") as batch_op: + batch_op.add_column(sa.Column("tuning_params", sa.Text(), nullable=True)) + batch_op.add_column(sa.Column("backtest_mae", sa.Float(), nullable=True)) + batch_op.add_column(sa.Column("backtest_rmse", sa.Float(), nullable=True)) + batch_op.add_column(sa.Column("backtest_folds", sa.Integer(), nullable=False, server_default="0")) + + +def downgrade() -> None: + with op.batch_alter_table("load_consumption_models") as batch_op: + batch_op.drop_column("backtest_folds") + batch_op.drop_column("backtest_rmse") + batch_op.drop_column("backtest_mae") + batch_op.drop_column("tuning_params") diff --git a/alembic/versions/e5f6a7b8c9d0_add_home_loads_profile_to_optimization_units.py b/alembic/versions/e5f6a7b8c9d0_add_home_loads_profile_to_optimization_units.py new file mode 100644 index 0000000..a898849 --- /dev/null +++ b/alembic/versions/e5f6a7b8c9d0_add_home_loads_profile_to_optimization_units.py @@ -0,0 +1,29 @@ +"""Add home_loads_profile to optimization_units + +Revision ID: e5f6a7b8c9d0 +Revises: d4e5f6a7b8c9 +Create Date: 2026-04-29 12:00:00.000000 + +""" + +from typing import Sequence, Union + +import sqlalchemy as sa + +from alembic import op + +# revision identifiers, used by Alembic. 
+revision: str = "e5f6a7b8c9d0" +down_revision: Union[str, None] = "d4e5f6a7b8c9" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + with op.batch_alter_table("optimization_units") as batch_op: + batch_op.add_column(sa.Column("home_loads_profile", sa.String(), nullable=True)) + + +def downgrade() -> None: + with op.batch_alter_table("optimization_units") as batch_op: + batch_op.drop_column("home_loads_profile") diff --git a/data/examples/policies/550e8400-e29b-41d4-a716-446655440000.yaml b/data/examples/policies/550e8400-e29b-41d4-a716-446655440000.yaml index a1f1de0..bec94a5 100644 --- a/data/examples/policies/550e8400-e29b-41d4-a716-446655440000.yaml +++ b/data/examples/policies/550e8400-e29b-41d4-a716-446655440000.yaml @@ -119,7 +119,7 @@ stop_rules: description: Stop mining during high home energy demand periods conditions: all_of: - - field: home_load_forecast + - field: home_load.total_forecast.next_2h.avg_power operator: gt value: 2800 - field: timestamp.hour @@ -136,6 +136,6 @@ stop_rules: enabled: true metadata: author: Edge Mining User - version: 5 + version: 6 created: '2025-08-04' - last_modified: '2026-01-18' + last_modified: '2026-04-22' diff --git a/data/examples/rules/start/advanced_start_rules.yaml b/data/examples/rules/start/advanced_start_rules.yaml index 06745fa..6480f19 100644 --- a/data/examples/rules/start/advanced_start_rules.yaml +++ b/data/examples/rules/start/advanced_start_rules.yaml @@ -38,9 +38,9 @@ rules: - field: "energy_state.battery.state_of_charge" operator: "gt" value: 60 - - field: "home_load_forecast" + - field: "home_load.total_forecast.next_2h.avg_power" operator: "lt" - value: 2000 # Low expected home consumption + value: 2000 # Low expected home consumption (next 2h average) - any_of: - field: "energy_state.production" operator: "gt" diff --git a/data/examples/rules/start/home_load_start_rules.yaml 
b/data/examples/rules/start/home_load_start_rules.yaml new file mode 100644 index 0000000..9a676da --- /dev/null +++ b/data/examples/rules/start/home_load_start_rules.yaml @@ -0,0 +1,45 @@ +rules: + - name: "Low Home Load Forecast" + description: "Start mining when total household forecast is low for the next 2 hours" + conditions: + all_of: + - field: "home_load.total_forecast.next_2h.total_energy" + operator: "lt" + value: 1000 # Less than 1 kWh total in the next 2 hours + - field: "energy_state.battery.state_of_charge" + operator: "gt" + value: 40 + priority: 15 + enabled: true + + - name: "Low Boiler Forecast + Excess Solar" + description: "Start mining when boiler forecast is low and solar production is good" + conditions: + all_of: + - field: "home_load.devices.boiler.forecast.next_1h.peak_power" + operator: "lt" + value: 500 # Boiler expected to draw less than 500W in next hour + - field: "energy_state.production" + operator: "gt" + value: 2000 + - field: "energy_state.battery.state_of_charge" + operator: "gt" + value: 50 + priority: 12 + enabled: true + + - name: "Low Historical Load + Good Forecast" + description: "Start mining when recent load history is low and solar forecast is positive" + conditions: + all_of: + - field: "home_load.total_history.last_1h.avg_power" + operator: "lt" + value: 800 # Average power in the last hour below 800W + - field: "forecast.next_hour_power" + operator: "gt" + value: 1500 + - field: "energy_state.battery.state_of_charge" + operator: "gt" + value: 45 + priority: 10 + enabled: true diff --git a/data/examples/rules/stop/advanced_stop_rules.yaml b/data/examples/rules/stop/advanced_stop_rules.yaml index 173f872..847f720 100644 --- a/data/examples/rules/stop/advanced_stop_rules.yaml +++ b/data/examples/rules/stop/advanced_stop_rules.yaml @@ -48,9 +48,9 @@ rules: description: "Stop mining during high home energy demand periods" conditions: all_of: - - field: "home_load_forecast" + - field: 
"home_load.total_forecast.next_2h.avg_power" operator: "gt" - value: 2800 # High predicted home consumption + value: 2800 # High predicted home consumption (next 2h average) - field: "timestamp.hour" operator: "in" value: [17, 18, 19, 20, 21, 22] # Evening hours @@ -121,8 +121,8 @@ rules: - field: "forecast.next_hour_power" operator: "lt" value: 500 # Very low next hour forecast - - field: "home_load_forecast" + - field: "home_load.total_forecast.next_2h.avg_power" operator: "gt" - value: 2500 # High home consumption expected + value: 2500 # High home consumption expected (next 2h average) priority: 30 enabled: true diff --git a/data/examples/rules/stop/basic_stop_rules.yaml b/data/examples/rules/stop/basic_stop_rules.yaml index d02fbb2..92e9768 100644 --- a/data/examples/rules/stop/basic_stop_rules.yaml +++ b/data/examples/rules/stop/basic_stop_rules.yaml @@ -26,7 +26,7 @@ rules: description: "Stop mining when home load is high and battery is not full" conditions: all_of: - - field: "home_load_forecast" + - field: "home_load.total_forecast.next_2h.avg_power" operator: "gt" value: 2500 - field: "energy_state.battery.state_of_charge" diff --git a/data/examples/rules/stop/home_load_stop_rules.yaml b/data/examples/rules/stop/home_load_stop_rules.yaml new file mode 100644 index 0000000..62a20d5 --- /dev/null +++ b/data/examples/rules/stop/home_load_stop_rules.yaml @@ -0,0 +1,62 @@ +rules: + - name: "High Home Load Forecast" + description: "Stop mining when total household forecast is high for the next 2 hours" + conditions: + all_of: + - field: "home_load.total_forecast.next_2h.avg_power" + operator: "gt" + value: 2500 # Household expected to draw over 2.5 kW on average + - field: "energy_state.battery.state_of_charge" + operator: "lt" + value: 70 + priority: 60 + enabled: true + + - name: "Boiler Peak + Low Battery" + description: "Stop mining when boiler is forecasted to peak and battery is not high" + conditions: + all_of: + - field: 
"home_load.devices.boiler.forecast.next_1h.peak_power" + operator: "gt" + value: 2000 # Boiler expected to draw over 2 kW peak in next hour + - field: "energy_state.battery.state_of_charge" + operator: "lt" + value: 60 + priority: 55 + enabled: true + + - name: "Evening High Load History" + description: "Stop mining during evening hours when recent consumption is high" + conditions: + all_of: + - field: "timestamp.hour" + operator: "in" + value: [17, 18, 19, 20, 21, 22] + - field: "home_load.total_history.last_1h.avg_power" + operator: "gt" + value: 2000 # Currently consuming over 2 kW average + - field: "energy_state.battery.state_of_charge" + operator: "lt" + value: 75 + priority: 50 + enabled: true + + - name: "Multiple Devices High Forecast" + description: "Stop mining when total forecast for next 4 hours shows sustained high load" + conditions: + all_of: + - field: "home_load.total_forecast.next_4h.avg_power" + operator: "gt" + value: 1800 + - field: "home_load.total_forecast.next_4h.peak_power" + operator: "gt" + value: 3000 # Peak expected above 3 kW + - any_of: + - field: "energy_state.battery.state_of_charge" + operator: "lt" + value: 50 + - field: "energy_state.production" + operator: "lt" + value: 1000 + priority: 45 + enabled: true diff --git a/data/policies/550e8400-e29b-41d4-a716-446655440000.yaml b/data/policies/550e8400-e29b-41d4-a716-446655440000.yaml index a1f1de0..bec94a5 100644 --- a/data/policies/550e8400-e29b-41d4-a716-446655440000.yaml +++ b/data/policies/550e8400-e29b-41d4-a716-446655440000.yaml @@ -119,7 +119,7 @@ stop_rules: description: Stop mining during high home energy demand periods conditions: all_of: - - field: home_load_forecast + - field: home_load.total_forecast.next_2h.avg_power operator: gt value: 2800 - field: timestamp.hour @@ -136,6 +136,6 @@ stop_rules: enabled: true metadata: author: Edge Mining User - version: 5 + version: 6 created: '2025-08-04' - last_modified: '2026-01-18' + last_modified: '2026-04-22' diff --git 
a/docs/home_load/home_load_forecast_providers.md b/docs/home_load/home_load_forecast_providers.md new file mode 100644 index 0000000..eb3e101 --- /dev/null +++ b/docs/home_load/home_load_forecast_providers.md @@ -0,0 +1,433 @@ +# Home Load Forecast Providers + +This document describes all available **Energy Load Forecast Providers** in +EdgeMining. Each provider implements `EnergyLoadForecastProviderPort` and +produces a `LoadEnergyConsumption` forecast for a configurable time horizon. + +Providers are selected per-device via the `EnergyLoadForecastProviderAdapter` +enum and configured through a corresponding dataclass. + +--- + +## Provider Summary + +| Adapter Enum | Provider Class | Category | Dependencies | Pre-trained Model | +|---|---|---|---|---| +| `DUMMY` | `DummyEnergyLoadForecastProvider` | Testing | None | No | +| `NAIVE_LAST_HOUR` | `NaiveLastHourForecastProvider` | Baseline | None | No | +| `NAIVE_PERSISTENCE` | `NaivePersistenceForecastProvider` | Baseline | None | No | +| `SEASONAL_BASELINE` | `SeasonalBaselineForecastProvider` | Statistical | None | No | +| `TYPICAL_PROFILE` | `TypicalProfileForecastProvider` | Statistical | None | No | +| `STATSMODELS` | `StatsmodelsForecastProvider` | ML | `statsmodels` | Yes | +| `XGBOOST` | `XGBoostForecastProvider` | ML | `xgboost` | Yes | +| `SKFORECAST` | `SkforecastForecastProvider` | ML | `skforecast`, `scikit-learn` | Yes | + +--- + +## Baseline Providers + +### DUMMY + +**Purpose**: development and testing only. Generates random power values so +that the rest of the pipeline can run without real sensor data. + +**Algorithm**: if history is available, takes the last interval's average power +as baseline; otherwise picks a random value in `[200, load_power_max]`. Each +forecast hour applies small random noise. 
+ +| Parameter | Type | Default | Description | +|---|---|---|---| +| `load_power_max` | `float` | `500.0` | Upper bound for generated power (W) | + +| Property | Value | +|---|---| +| Min required history | 0 hours | +| Forecast horizon | N/A (config-driven) | +| File | `adapters/domain/home_load/forecast_providers/dummy.py` | + +--- + +### NAIVE_LAST_HOUR + +**Purpose**: simplest real-world baseline. Repeats the most recent measured +power into the future — useful as a short-horizon fallback when no other +provider is available. + +**Algorithm**: computes the average power over the last 1 hour of history and +projects that flat value for every forecast hour. Falls back to the overall +history average if the last hour has no data. + +| Parameter | Type | Default | Description | +|---|---|---|---| +| `hours_ahead` | `int` | `3` | Forecast horizon in hours | + +| Property | Value | +|---|---| +| Min required history | 1 hour | +| Best for | Very short horizons (1–3 h), instant fallback | +| File | `adapters/domain/home_load/forecast_providers/naive_last_hour.py` | + +--- + +### NAIVE_PERSISTENCE + +**Purpose**: strong intra-day baseline that assumes tomorrow looks like +yesterday. + +**Algorithm**: builds an `hour → power` map from the same calendar date +`delta_days` ago, then replays that 24 h profile forward. Missing hours fall +back to the global history average. 
+ +| Parameter | Type | Default | Description | +|---|---|---|---| +| `hours_ahead` | `int` | `24` | Forecast horizon in hours | +| `delta_days` | `int` | `1` | How many days back to look (1 = yesterday) | + +| Property | Value | +|---|---| +| Min required history | `delta_days × 24` hours (default 24) | +| Best for | Devices with regular daily patterns; ML-free fallback | +| File | `adapters/domain/home_load/forecast_providers/naive_persistence.py` | + +--- + +## Statistical Providers + +### SEASONAL_BASELINE + +**Purpose**: lightweight statistical forecast that captures weekly +seasonality without any ML dependency. + +**Algorithm**: groups all history by `(day_of_week, hour_of_day)` and averages +each slot. For each forecast hour, looks up the matching `(dow, hod)` bucket. +Falls back to the global average across all slots. + +| Parameter | Type | Default | Description | +|---|---|---|---| +| `hours_ahead` | `int` | `3` | Forecast horizon in hours | +| `weeks_lookback` | `int` | `4` | Weeks of history to consider | + +| Property | Value | +|---|---| +| Min required history | 0 hours (degrades gracefully) | +| Best for | Quick start with ≥1 week of data | +| File | `adapters/domain/home_load/forecast_providers/seasonal_baseline.py` | + +--- + +### TYPICAL_PROFILE + +**Purpose**: more refined statistical forecast that adds **monthly** grouping +on top of weekly seasonality. Captures how consumption changes across seasons +(e.g. heating in winter vs. cooling in summer). + +**Algorithm**: two-level profile lookup: +1. **Primary**: `(month, day_of_week, hour_of_day)` — average power for this + exact month + weekday + hour combination. +2. **Fallback**: `(day_of_week, hour_of_day)` — ignores month, same as + `SEASONAL_BASELINE` logic. +3. **Global**: overall average if both levels miss. 
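The two-level lookup above can be sketched as follows (hypothetical helper names; the provider's real implementation may differ):

```python
from collections import defaultdict
from statistics import mean


def build_typical_profile(history):
    """Bucket (timestamp, power) samples at two levels of specificity."""
    primary, fallback, all_values = defaultdict(list), defaultdict(list), []
    for ts, power in history:
        primary[(ts.month, ts.weekday(), ts.hour)].append(power)
        fallback[(ts.weekday(), ts.hour)].append(power)
        all_values.append(power)
    return primary, fallback, (mean(all_values) if all_values else 0.0)


def typical_profile_lookup(ts, primary, fallback, global_avg):
    """Primary (month, dow, hod) -> fallback (dow, hod) -> global average."""
    key3 = (ts.month, ts.weekday(), ts.hour)
    if key3 in primary:
        return mean(primary[key3])
    key2 = (ts.weekday(), ts.hour)
    if key2 in fallback:
        return mean(fallback[key2])
    return global_avg
```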
+ +| Parameter | Type | Default | Description | +|---|---|---|---| +| `hours_ahead` | `int` | `24` | Forecast horizon in hours | +| `weeks_lookback` | `int` | `8` | Weeks of history to consider | + +| Property | Value | +|---|---| +| Min required history | `weeks_lookback × 168` hours (default 1 344 h ≈ 8 weeks) | +| Best for | Devices with seasonal patterns; new installations with ≥2 months of data | +| File | `adapters/domain/home_load/forecast_providers/typical_profile.py` | + +--- + +## ML Providers + +All ML providers share these traits: + +- **Lazy imports**: heavy dependencies (`statsmodels`, `xgboost`, `skforecast`) + are imported at runtime. If a library is missing, the provider gracefully + returns `None` (except `STATSMODELS` which raises). +- **Pre-trained model support**: each looks for an active `LoadConsumptionModel` + in the model repository. If found, the serialised model is loaded via + `pickle`. Otherwise, the provider fits on-the-fly from history. +- **Nightly training**: `LoadForecastModelTrainingService.train_all()` trains + all ML providers (HW, XGBoost, skforecast), evaluates on a 24 h holdout, + and promotes the best model (lowest MAE) to active. + +### STATSMODELS + +**Purpose**: Holt-Winters exponential smoothing — a classical time-series +method that captures trend and daily seasonality (period = 24 h). + +**Algorithm**: `ExponentialSmoothing(trend="add", seasonal="add", +seasonal_periods=24)` from `statsmodels`. Fits on hourly power series derived +from history intervals. Forecast calls `fitted.forecast(hours_ahead)`. 
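A minimal sketch of this fit-and-forecast flow, mirroring the lazy-import convention described above (the wrapper function is illustrative; `ExponentialSmoothing` and its arguments are the statsmodels API named in the algorithm description):

```python
def holt_winters_forecast(hourly_series, hours_ahead=3, seasonal_periods=24):
    """Fit Holt-Winters (additive trend + seasonality) on an hourly power
    series and forecast `hours_ahead` values. Mirrors the providers'
    lazy-import style: returns None when statsmodels is missing or the
    history is shorter than two seasonal cycles."""
    if len(hourly_series) < 2 * seasonal_periods:
        return None  # not enough history for a seasonal fit
    try:
        from statsmodels.tsa.holtwinters import ExponentialSmoothing
    except ImportError:
        return None  # optional [ml] dependency not installed
    fitted = ExponentialSmoothing(
        hourly_series,
        trend="add",
        seasonal="add",
        seasonal_periods=seasonal_periods,
    ).fit()
    return list(fitted.forecast(hours_ahead))
```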
+ +| Parameter | Type | Default | Description | +|---|---|---|---| +| `hours_ahead` | `int` | `3` | Forecast horizon in hours | +| `weeks_lookback` | `int` | `8` | Weeks of history for training | +| `method` | `str` | `"hw"` | Model family (`"hw"` = Holt-Winters; `"sarima"` reserved) | +| `seasonal_periods` | `int` | `24` | Seasonal cycle length in hours | + +| Property | Value | +|---|---| +| Min required history | `seasonal_periods × 2` hours (default 48) | +| Best for | Smooth loads with clear 24 h seasonality (e.g. household aggregate) | +| File | `adapters/domain/home_load/forecast_providers/statsmodels_hw.py` | + +--- + +### XGBOOST + +**Purpose**: gradient-boosted trees using hand-crafted calendar + lag features +with iterative 1-step-ahead prediction. + +**Algorithm**: trains an `XGBRegressor` on a supervised dataset built from: +- **Calendar features**: `hour_of_day`, `day_of_week`, `is_weekend`, `month`. +- **Lag features**: power at `t-1h`, `t-2h`, `t-3h`, `t-24h`, `t-168h`. + +Prediction iterates 1 step at a time, appending the previous prediction as +the next lag input. + +| Parameter | Type | Default | Description | +|---|---|---|---| +| `hours_ahead` | `int` | `3` | Forecast horizon in hours | +| `weeks_lookback` | `int` | `8` | Weeks of history for training | +| `n_estimators` | `int` | `100` | Number of boosting rounds | +| `max_depth` | `int` | `6` | Maximum tree depth | +| `learning_rate` | `float` | `0.1` | Boosting learning rate | + +| Property | Value | +|---|---| +| Min required history | `168 + 48 + hours_ahead` hours (default 219) | +| Best for | Non-linear patterns, devices with strong weekly periodicity | +| File | `adapters/domain/home_load/forecast_providers/xgboost_provider.py` | + +--- + +### SKFORECAST + +**Purpose**: auto-regressive multi-step forecasting via `skforecast`'s +`ForecasterRecursive`, wrapping **any** scikit-learn regressor. 
The forecaster
+feeds its own predictions back as input for subsequent steps, producing native
+multi-step forecasts without manual lag iteration.
+
+**Algorithm**: `ForecasterRecursive(estimator=<sklearn_model>, lags=num_lags)`
+fits on hourly power series. Prediction calls `forecaster.predict(steps=N)`.
+
+**Supported sklearn backends** (selected via `sklearn_model` config string):
+
+| Backend | Strengths | Best for |
+|---|---|---|
+| `RandomForestRegressor` | Robust to outliers, feature importance | Medium-large datasets |
+| `GradientBoostingRegressor` | High accuracy, handles non-linearity | Production |
+| `ExtraTreesRegressor` | Fast training, good trade-off | Quick screening |
+| `KNeighborsRegressor` | No heavy training, adaptive | Regular profiles |
+| `Ridge` | Interpretable, very fast | Linear relationships |
+| `Lasso` | Sparse features, fast | Feature selection |
+| `ElasticNet` | Mix of Ridge + Lasso | Balanced regularisation |
+| `AdaBoostRegressor` | Adaptive boosting | Bias reduction |
+| `MLPRegressor` | Captures complex patterns | Large datasets |
+| `SVR` | Good on small datasets | Low-data scenarios |
+
+| Parameter | Type | Default | Description |
+|---|---|---|---|
+| `hours_ahead` | `int` | `24` | Forecast horizon in hours |
+| `weeks_lookback` | `int` | `8` | Weeks of history for training |
+| `sklearn_model` | `str` | `"RandomForestRegressor"` | Name of the sklearn regressor class |
+| `num_lags` | `int` | `72` | Number of lag observations used as features |
+
+| Property | Value |
+|---|---|
+| Min required history | `num_lags + 48 + hours_ahead` hours (default 144) |
+| Best for | General-purpose ML forecasting with model competition |
+| File | `adapters/domain/home_load/forecast_providers/skforecast_provider.py` |
+
+#### Optuna Bayesian Tuning (F6)
+
+The `SkforecastForecastProvider.tune()` static method runs Bayesian
+hyperparameter optimisation via `optuna` + `bayesian_search_forecaster`.
+It searches: + +- **Lag count**: categorical over `[24, 48, 72]` +- **Model hyperparameters**: per-model search spaces (e.g. `n_estimators`, + `max_depth`, `learning_rate`, `alpha`, `n_neighbors`) + +The training service calls `tune()` automatically during nightly training +(configurable via `perform_tuning` / `tuning_trials` parameters). Best +parameters are stored in `LoadConsumptionModel.tuning_params`. + +#### Rolling-Window Backtesting (F7) + +The `SkforecastForecastProvider.backtest()` static method evaluates a fitted +forecaster on the full training set using `backtesting_forecaster` with +`TimeSeriesFold`. Returns: + +- `backtest_mae` — MAE across all folds +- `backtest_rmse` — RMSE across all folds +- `backtest_folds` — number of evaluation windows + +Backtesting runs automatically after training. Results are stored on the +`LoadConsumptionModel` entity alongside the holdout metrics. + +--- + +## Choosing a Provider + +``` +Is this for development/testing? + └─ Yes → DUMMY + +Do you have < 1 hour of history? + └─ Yes → NAIVE_LAST_HOUR (flat repeat of last reading) + +Do you have ~ 1 day of history? + └─ Yes → NAIVE_PERSISTENCE (yesterday's profile) + +Do you have 1–4 weeks of history? + └─ Yes → SEASONAL_BASELINE (weekly pattern average) + +Do you have 2+ months of history? + └─ Yes → TYPICAL_PROFILE (monthly + weekly pattern) + +Do you have 1+ week and want ML? + └─ Yes → STATSMODELS (Holt-Winters) or XGBOOST + +Do you have 1+ week, want best accuracy, and can install skforecast? + └─ Yes → SKFORECAST (auto-regressive multi-model with tuning) +``` + +In production, **SKFORECAST** is recommended as the primary provider. The +nightly training service automatically competes Holt-Winters, XGBoost, and +skforecast models, promoting the best one. The simpler providers +(`NAIVE_PERSISTENCE`, `SEASONAL_BASELINE`) serve as robust fallbacks. 
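+The decision tree above can be sketched as a plain helper (the function name
+and flags here are illustrative, not part of the codebase; 168 h = 1 week,
+1344 h ≈ 2 months):
+
+```python
+def choose_provider(
+    history_hours: float,
+    dev_mode: bool = False,
+    ml_installed: bool = False,
+    skforecast_installed: bool = False,
+) -> str:
+    """Walk the questions top-to-bottom, preferring ML providers when
+    enough history and the optional dependencies are available."""
+    if dev_mode:
+        return "DUMMY"
+    if history_hours < 1:
+        return "NAIVE_LAST_HOUR"
+    if skforecast_installed and history_hours >= 168:
+        return "SKFORECAST"  # recommended production default
+    if ml_installed and history_hours >= 168:
+        return "STATSMODELS"  # or XGBOOST
+    if history_hours >= 1344:
+        return "TYPICAL_PROFILE"
+    if history_hours >= 168:
+        return "SEASONAL_BASELINE"
+    return "NAIVE_PERSISTENCE"  # ~1 day of history
+```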
+ +--- + +## Architecture + +### Port & Adapter Pattern + +``` +Domain Adapters +┌──────────────────┐ ┌───────────────────────────┐ +│ EnergyLoadFore- │ │ DummyProvider │ +│ castProviderPort │◄──────────│ NaiveLastHourProvider │ +│ │ │ NaivePersistenceProvider │ +│ + adapter_type │ │ SeasonalBaselineProvider │ +│ + min_history │ │ TypicalProfileProvider │ +│ + get_forecast()│ │ StatsmodelsProvider │ +│ │ │ XGBoostProvider │ +│ │ │ SkforecastProvider │ +└──────────────────┘ └───────────────────────────┘ +``` + +Each provider has: + +1. **Enum value** in `EnergyLoadForecastProviderAdapter` (domain layer) +2. **Config dataclass** in `shared/adapter_configs/home_load.py` +3. **Factory class** implementing `EnergyLoadForecastAdapterFactory` +4. **Schema class** in `adapters/domain/home_load/schemas.py` +5. **Wiring** in `AdapterService` factory dispatch + `adapter_maps` + +### Shared Feature Engineering + +ML providers share helper functions from +`adapters/domain/home_load/forecast_providers/features.py`: + +| Function | Used by | Description | +|---|---|---| +| `intervals_to_hourly_series()` | HW, XGB, Skforecast | Converts `LoadEnergyConsumption` intervals to `[(timestamp, power)]` | +| `fill_missing_hours()` | HW, XGB, Skforecast | Forward-fills gaps in hourly series | +| `build_calendar_features()` | XGB | Extracts `[hour, dow, is_weekend, month]` | +| `build_lag_features()` | XGB | Creates lag columns `[1h, 2h, 3h, 24h, 168h]` | +| `prepare_supervised_dataset()` | XGB | Combines calendar + lag features into `(X, y)` | + +### Model Lifecycle + +``` +Nightly Training (04:00) + │ + ├─ For each enabled device: + │ ├─ Fetch 8 weeks of history + │ ├─ Split: train (all - 24h) / holdout (last 24h) + │ │ + │ ├─ _train_hw() → LoadConsumptionModel (STATSMODELS) + │ ├─ _train_xgb() → LoadConsumptionModel (XGBOOST) + │ └─ _train_skforecast() → LoadConsumptionModel (SKFORECAST) + │ ├─ Optuna tuning (optional, default ON) + │ └─ Rolling-window backtesting + │ + │ Compare MAE → 
promote best → is_active = True + │ Persist all models to LoadConsumptionModelRepository + │ + └─ Done + +Forecast (every 5s optimisation loop) + │ + ├─ Provider checks model_repo for active model + ├─ If found → pickle.loads() → predict + └─ If not → fit on-the-fly from history → predict +``` + +--- + +## LoadConsumptionModel Entity + +The `LoadConsumptionModel` entity stores trained model metadata and weights: + +| Field | Type | Description | +|---|---|---| +| `device_id` | `Optional[EntityId]` | Device this model was trained for (`None` = aggregate) | +| `adapter_type` | `EnergyLoadForecastProviderAdapter` | Which ML provider created it | +| `trained_at` | `Optional[datetime]` | Training timestamp | +| `mae` | `Optional[float]` | Mean Absolute Error on 24 h holdout | +| `rmse` | `Optional[float]` | Root Mean Squared Error on 24 h holdout | +| `samples_used` | `int` | Number of training data points | +| `is_active` | `bool` | Whether this model is the current production model | +| `model_bytes` | `Optional[bytes]` | Serialised model (pickle) | +| `tuning_params` | `Optional[dict]` | Best hyperparameters from Optuna tuning | +| `backtest_mae` | `Optional[float]` | MAE from rolling-window backtesting | +| `backtest_rmse` | `Optional[float]` | RMSE from rolling-window backtesting | +| `backtest_folds` | `int` | Number of backtesting evaluation folds | + +### ML Model Competition and Promotion + +During each training run (nightly at 04:00 or triggered manually via API), +the service trains **three candidate models** for every enabled device: +Holt-Winters (STATSMODELS), XGBoost, and Skforecast. Each candidate is +evaluated against a **holdout set** consisting of the last 24 hours of +history data. + +**Selection criterion**: the candidate with the **lowest MAE** (Mean Absolute +Error) on the holdout wins and is promoted to `is_active = True`. All +previously active models for that device are demoted to `is_active = False`. 
+ +```python +candidates = [hw_model, xgb_model, skf_model] +best = min(candidates, key=lambda m: m.mae) +best.is_active = True +``` + +**What `is_active` means in practice**: + +| `is_active` | Meaning | +|---|---| +| `True` | This is the **production model** — the forecast provider will load and use it for predictions | +| `False` | Historical/archived model — kept for audit and comparison but not used for live forecasts | + +**How providers consume the active model**: when a forecast provider +(STATSMODELS, XGBOOST, or SKFORECAST) needs to produce a prediction, it +queries `LoadConsumptionModelRepository.get_active_model(adapter_type, +device_id)`. If an active model exists, it is deserialised from +`model_bytes` via `pickle.loads()` and used directly. If no active model is +found (e.g. training has never run), the provider falls back to +**fit-on-the-fly** from the supplied history — slower but ensures a forecast +is always available. + +**Other stored metrics** (`rmse`, `backtest_mae`, `backtest_rmse`, +`backtest_folds`, `tuning_params`) are **informational only** — they are not +used for model selection. They are persisted for monitoring, comparison, and +debugging via the `GET /training/models` API endpoint. 
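+The "active model or fit-on-the-fly" path can be sketched with a toy
+repository and model (illustrative stand-ins; the real
+`LoadConsumptionModelRepository` implementations are InMemory, SQLite, and
+SQLAlchemy):
+
+```python
+import pickle
+from statistics import mean
+
+class MeanModel:
+    """Toy stand-in for a trained forecaster."""
+    def __init__(self, level):
+        self.level = level
+    def predict(self, steps):
+        return [self.level] * steps
+
+class InMemoryModelRepo:
+    """Minimal repository: stores serialised models with an active flag."""
+    def __init__(self):
+        self._rows = []
+    def save(self, adapter_type, device_id, model, is_active):
+        self._rows.append((adapter_type, device_id, is_active, pickle.dumps(model)))
+    def get_active_model(self, adapter_type, device_id):
+        for a, d, active, blob in self._rows:
+            if a == adapter_type and d == device_id and active:
+                return blob
+        return None
+
+def get_forecast(repo, adapter_type, device_id, history, steps=3):
+    blob = repo.get_active_model(adapter_type, device_id)
+    if blob is not None:
+        model = pickle.loads(blob)        # production path: promoted model
+    else:
+        model = MeanModel(mean(history))  # fallback: fit on-the-fly
+    return model.predict(steps)
+```
+
+Before any training run the fallback still yields a forecast; once a nightly
+run has promoted a model, the pickled one is deserialised and used instead.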
diff --git a/edge_mining/__main__.py b/edge_mining/__main__.py index 485a669..ee41418 100644 --- a/edge_mining/__main__.py +++ b/edge_mining/__main__.py @@ -73,6 +73,8 @@ async def main_async(): optimization_service=services.optimization_service, logger=logger, settings=settings, + home_load_history_service=services.home_load_history_service, + load_forecast_training_service=services.load_forecast_training_service, ) await asyncio.gather( diff --git a/edge_mining/adapters/domain/home_load/fast_api/__init__.py b/edge_mining/adapters/domain/home_load/fast_api/__init__.py new file mode 100644 index 0000000..ab53578 --- /dev/null +++ b/edge_mining/adapters/domain/home_load/fast_api/__init__.py @@ -0,0 +1 @@ +"""Adapter that uses FastAPI infrastructure for Home load domain API""" diff --git a/edge_mining/adapters/domain/home_load/fast_api/router.py b/edge_mining/adapters/domain/home_load/fast_api/router.py new file mode 100644 index 0000000..ddb599a --- /dev/null +++ b/edge_mining/adapters/domain/home_load/fast_api/router.py @@ -0,0 +1,969 @@ +"""API Router for home load domain.""" + +import uuid +from datetime import datetime, timedelta, timezone +from typing import Annotated, Any, Dict, List, Optional, cast + +from fastapi import APIRouter, Depends, HTTPException, Query + +from edge_mining.adapters.domain.home_load.history_providers.helpers import group_power_points_into_intervals +from edge_mining.adapters.domain.home_load.schemas import ( + ENERGY_LOAD_FORECAST_PROVIDER_CONFIG_SCHEMA_MAP, + ENERGY_LOAD_HISTORY_PROVIDER_CONFIG_SCHEMA_MAP, + EnergyLoadForecastProviderCreateSchema, + EnergyLoadForecastProviderSchema, + EnergyLoadForecastProviderUpdateSchema, + EnergyLoadHistoryProviderCreateSchema, + EnergyLoadHistoryProviderSchema, + EnergyLoadHistoryProviderUpdateSchema, + HomeLoadPowerPointSchema, + HomeLoadsProfileSchema, + LoadConsumptionModelSchema, + LoadDeviceCreateSchema, + LoadDeviceSchema, + LoadDeviceUpdateSchema, + LoadEnergyConsumptionSchema, +) + +# 
Import dependency injection setup functions +from edge_mining.adapters.infrastructure.api.setup import ( + get_adapter_service, + get_config_service, + get_home_load_history_service, + get_load_forecast_training_service, +) +from edge_mining.application.interfaces import ( + AdapterServiceInterface, + ConfigurationServiceInterface, + HomeLoadHistoryServiceInterface, + LoadForecastTrainingServiceInterface, +) +from edge_mining.domain.common import EntityId, Timestamp +from edge_mining.domain.home_load.aggregate_roots import HomeLoadsProfile +from edge_mining.domain.home_load.common import ( + EnergyLoadForecastProviderAdapter, + EnergyLoadHistoryProviderAdapter, + LoadDeviceCategory, +) +from edge_mining.domain.home_load.entities import EnergyLoadForecastProvider, EnergyLoadHistoryProvider, LoadDevice +from edge_mining.domain.home_load.exceptions import ( + EnergyLoadForecastProviderAlreadyExistsError, + EnergyLoadForecastProviderConfigurationError, + EnergyLoadForecastProviderError, + EnergyLoadForecastProviderNotFoundError, + EnergyLoadHistoryProviderAlreadyExistsError, + EnergyLoadHistoryProviderConfigurationError, + EnergyLoadHistoryProviderNotFoundError, + HomeLoadsProfileAddDeviceError, + HomeLoadsProfileAlreadyExistsError, + HomeLoadsProfileDeviceNotFoundError, + HomeLoadsProfileNotFoundError, + HomeLoadsProfileRemoveDeviceError, +) +from edge_mining.domain.home_load.value_objects import LoadEnergyConsumption +from edge_mining.shared.adapter_maps.home_load import ( + ENERGY_LOAD_FORECAST_PROVIDER_CONFIG_TYPE_MAP, + ENERGY_LOAD_HISTORY_PROVIDER_CONFIG_TYPE_MAP, +) +from edge_mining.shared.external_services.common import ExternalServiceAdapter +from edge_mining.shared.interfaces.config import EnergyLoadForecastProviderConfig, EnergyLoadHistoryProviderConfig + +router = APIRouter() + + +# Home Loads Profile endpoints +@router.get("/home-loads-profiles", response_model=List[HomeLoadsProfileSchema]) +async def get_home_loads_profiles_list( + config_service: 
Annotated[ConfigurationServiceInterface, Depends(get_config_service)], +) -> List[HomeLoadsProfileSchema]: + """Get a list of all home loads profiles.""" + try: + profiles: List[HomeLoadsProfile] = config_service.list_home_loads_profiles() + + # Convert to home loads profile schema + profile_schemas: List[HomeLoadsProfileSchema] = [] + + for profile in profiles: + profile_schemas.append(HomeLoadsProfileSchema.from_model(profile)) + + return profile_schemas + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) from e + + +@router.post("/home-loads-profiles", response_model=HomeLoadsProfileSchema) +async def add_home_loads_profile( + profile_name: str, + config_service: Annotated[ConfigurationServiceInterface, Depends(get_config_service)], +) -> HomeLoadsProfileSchema: + """Add a new home loads profile.""" + try: + # Add the profile + added_profile = config_service.add_home_loads_profile(profile_name) + + # For now, return the created profile + return HomeLoadsProfileSchema.from_model(added_profile) + except HomeLoadsProfileAlreadyExistsError as e: + raise HTTPException(status_code=400, detail=str(e)) from e + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) from e + + +@router.get("/home-loads-profiles/{profile_id}", response_model=HomeLoadsProfileSchema) +async def get_home_loads_profile( + profile_id: EntityId, + config_service: Annotated[ConfigurationServiceInterface, Depends(get_config_service)], +) -> HomeLoadsProfileSchema: + """Get details of a specific home loads profile.""" + try: + profile = config_service.get_home_loads_profile(profile_id) + + if profile is None: + raise HomeLoadsProfileNotFoundError(f"Home Loads Profile with ID {profile_id} not found") + return HomeLoadsProfileSchema.from_model(profile) + except HomeLoadsProfileNotFoundError as e: + raise HTTPException(status_code=404, detail=str(e)) from e + except ValueError as e: + raise HTTPException(status_code=400, detail=str(e)) from e + except 
Exception as e:
+        raise HTTPException(status_code=500, detail=str(e)) from e
+
+
+@router.put("/home-loads-profiles/{profile_id}", response_model=HomeLoadsProfileSchema)
+async def update_home_loads_profile(
+    profile_id: EntityId,
+    profile_new_name: str,
+    config_service: Annotated[ConfigurationServiceInterface, Depends(get_config_service)],
+) -> HomeLoadsProfileSchema:
+    """Update an existing home loads profile."""
+    try:
+        profile = config_service.update_home_loads_profile(profile_id, profile_new_name)
+        if profile is None:
+            raise HomeLoadsProfileNotFoundError(f"Home Loads Profile with ID {profile_id} not found")
+        response = HomeLoadsProfileSchema.from_model(profile)
+        return response
+    except HomeLoadsProfileNotFoundError as e:
+        raise HTTPException(status_code=404, detail=str(e)) from e
+    except Exception as e:
+        raise HTTPException(status_code=500, detail=str(e)) from e
+
+
+@router.delete("/home-loads-profiles/{profile_id}", response_model=HomeLoadsProfileSchema)
+async def delete_home_loads_profile(
+    profile_id: EntityId,
+    config_service: Annotated[ConfigurationServiceInterface, Depends(get_config_service)],
+) -> HomeLoadsProfileSchema:
+    """Remove a home loads profile."""
+    try:
+        deleted_profile = config_service.remove_home_loads_profile(profile_id)
+        response = HomeLoadsProfileSchema.from_model(deleted_profile)
+        return response
+    except HomeLoadsProfileNotFoundError as e:
+        raise HTTPException(status_code=404, detail=str(e)) from e
+    except Exception as e:
+        raise HTTPException(status_code=500, detail=str(e)) from e
+
+
+# Load Device endpoints
+@router.get("/home-loads-profiles/{profile_id}/devices", response_model=List[LoadDeviceSchema])
+async def get_load_devices_list(
+    profile_id: EntityId,
+    config_service: Annotated[ConfigurationServiceInterface, Depends(get_config_service)],
+) -> List[LoadDeviceSchema]:
+    """Get a list of all load devices in a profile."""
+    try:
+        profile = config_service.get_home_loads_profile(profile_id)
+
+        if profile is None:
+            raise HomeLoadsProfileNotFoundError(f"Home Loads Profile with ID {profile_id} not found")
+
+        
devices: List[LoadDeviceSchema] = [] + for device in profile.devices: + devices.append(LoadDeviceSchema.from_model(device)) + + return devices + except HomeLoadsProfileNotFoundError as e: + raise HTTPException(status_code=404, detail=str(e)) from e + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) from e + + +@router.post("/home-loads-profiles/{profile_id}/devices", response_model=LoadDeviceSchema) +async def add_load_device( + profile_id: EntityId, + device_data: LoadDeviceCreateSchema, + config_service: Annotated[ConfigurationServiceInterface, Depends(get_config_service)], +) -> LoadDeviceSchema: + """Add a new load device to a profile.""" + try: + # Convert to domain model + device_to_add: LoadDevice = device_data.to_model() + + added_device = config_service.add_load_device_to_profile(profile_id=profile_id, load_device=device_to_add) + + if added_device is None: + raise HomeLoadsProfileAddDeviceError(f"Failed to add load device to profile {profile_id}") + + return LoadDeviceSchema.from_model(added_device) + except HomeLoadsProfileNotFoundError as e: + raise HTTPException(status_code=404, detail=str(e)) from e + except HomeLoadsProfileAddDeviceError as e: + raise HTTPException(status_code=400, detail=str(e)) from e + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) from e + + +@router.get("/home-loads-profiles/{profile_id}/devices/{device_id}", response_model=LoadDeviceSchema) +async def get_load_device( + profile_id: EntityId, + device_id: EntityId, + config_service: Annotated[ConfigurationServiceInterface, Depends(get_config_service)], +) -> LoadDeviceSchema: + """Get details of a specific load device.""" + try: + profile = config_service.get_home_loads_profile(profile_id) + + if profile is None: + raise HomeLoadsProfileNotFoundError(f"Home Loads Profile with ID {profile_id} not found") + + # Find the specific device + device = next((d for d in profile.devices if d.id == device_id), None) + + if device is 
None: + raise HomeLoadsProfileDeviceNotFoundError( + f"Load Device with ID {device_id} not found in Home Loads Profile {profile_id}" + ) + return LoadDeviceSchema.from_model(device) + except HomeLoadsProfileNotFoundError as e: + raise HTTPException(status_code=404, detail=str(e)) from e + except HomeLoadsProfileDeviceNotFoundError as e: + raise HTTPException(status_code=404, detail=str(e)) from e + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) from e + + +@router.put("/home-loads-profiles/{profile_id}/devices/{device_id}", response_model=LoadDeviceSchema) +async def update_load_device( + profile_id: EntityId, + device_id: EntityId, + device_update: LoadDeviceUpdateSchema, + config_service: Annotated[ConfigurationServiceInterface, Depends(get_config_service)], +) -> LoadDeviceSchema: + """Update an existing load device.""" + try: + profile = config_service.get_home_loads_profile(profile_id) + + if profile is None: + raise HomeLoadsProfileNotFoundError(f"Home Loads Profile with ID {profile_id} not found") + + # Find the specific device + device = next((d for d in profile.devices if d.id == device_id), None) + + if device is None: + raise HomeLoadsProfileDeviceNotFoundError( + f"Load Device with ID {device_id} not found in Home Loads Profile {profile_id}" + ) + + # Remove the old device + deleted_device = config_service.remove_load_device_from_profile(profile_id, device_id) + + if deleted_device is None: + raise HomeLoadsProfileRemoveDeviceError( + f"Failed to remove existing load device with ID {device_id} from profile {profile_id}" + ) + + # Add the updated device + forecast_provider_id = ( + EntityId(uuid.UUID(device_update.energy_load_forecast_provider_id)) + if device_update.energy_load_forecast_provider_id + else device.energy_load_forecast_provider_id + ) + history_provider_id = ( + EntityId(uuid.UUID(device_update.energy_load_history_provider_id)) + if device_update.energy_load_history_provider_id + else 
device.energy_load_history_provider_id + ) + category = ( + LoadDeviceCategory(device_update.category) + if isinstance(device_update.category, str) + else device_update.category + ) + new_device = LoadDevice( + id=device.id, + name=device_update.name or device.name, + category=category, + enabled=device_update.enabled, + energy_load_forecast_provider_id=forecast_provider_id, + energy_load_history_provider_id=history_provider_id, + ) + + device_added = config_service.add_load_device_to_profile( + profile_id=profile_id, + load_device=new_device, + ) + + if device_added is None: + raise HomeLoadsProfileAddDeviceError(f"Failed to add updated load device to profile {profile_id}") + + return LoadDeviceSchema.from_model(device_added) + except HomeLoadsProfileNotFoundError as e: + raise HTTPException(status_code=404, detail=str(e)) from e + except HomeLoadsProfileDeviceNotFoundError as e: + raise HTTPException(status_code=404, detail=str(e)) from e + except HomeLoadsProfileRemoveDeviceError as e: + raise HTTPException(status_code=400, detail=str(e)) from e + except HomeLoadsProfileAddDeviceError as e: + raise HTTPException(status_code=400, detail=str(e)) from e + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) from e + + +@router.delete("/home-loads-profiles/{profile_id}/devices/{device_id}", response_model=LoadDeviceSchema) +async def delete_load_device( + profile_id: EntityId, + device_id: EntityId, + config_service: Annotated[ConfigurationServiceInterface, Depends(get_config_service)], +) -> LoadDeviceSchema: + """Remove a load device from a profile.""" + try: + delete_load_device = config_service.remove_load_device_from_profile(profile_id, device_id) + + if delete_load_device is None: + raise HomeLoadsProfileRemoveDeviceError( + f"Failed to remove load device with ID {device_id} from profile {profile_id}" + ) + response = LoadDeviceSchema.from_model(delete_load_device) + return response + except HomeLoadsProfileNotFoundError as e: + raise 
HTTPException(status_code=404, detail=str(e)) from e + except HomeLoadsProfileRemoveDeviceError as e: + raise HTTPException(status_code=400, detail=str(e)) from e + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) from e + + +# Energy Load Forecast Provider endpoints +@router.get("/energy-load-forecast-providers", response_model=List[EnergyLoadForecastProviderSchema]) +async def get_energy_load_forecast_providers_list( + config_service: Annotated[ConfigurationServiceInterface, Depends(get_config_service)], +) -> List[EnergyLoadForecastProviderSchema]: + """Get a list of all energy load forecast providers.""" + try: + providers = config_service.list_energy_load_forecast_providers() + return [EnergyLoadForecastProviderSchema.from_model(p) for p in providers] + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) from e + + +@router.post("/energy-load-forecast-providers", response_model=EnergyLoadForecastProviderSchema) +async def add_energy_load_forecast_provider( + provider_data: EnergyLoadForecastProviderCreateSchema, + config_service: Annotated[ConfigurationServiceInterface, Depends(get_config_service)], +) -> EnergyLoadForecastProviderSchema: + """Add a new energy load forecast provider.""" + try: + provider_to_add: EnergyLoadForecastProvider = provider_data.to_model() + + if provider_to_add.config is None: + raise EnergyLoadForecastProviderConfigurationError( + "Energy Load Forecast provider configuration should be set" + ) + + added = config_service.add_energy_load_forecast_provider(provider_to_add) + return EnergyLoadForecastProviderSchema.from_model(added) + + except EnergyLoadForecastProviderAlreadyExistsError as e: + raise HTTPException(status_code=400, detail=str(e)) from e + except EnergyLoadForecastProviderConfigurationError as e: + raise HTTPException(status_code=400, detail=str(e)) from e + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) from e + + 
+@router.get("/energy-load-forecast-providers/types", response_model=List[EnergyLoadForecastProviderAdapter]) +async def get_energy_load_forecast_provider_types() -> List[EnergyLoadForecastProviderAdapter]: + """Get a list of available energy load forecast provider types.""" + try: + return [EnergyLoadForecastProviderAdapter(adapter.value) for adapter in EnergyLoadForecastProviderAdapter] + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) from e + + +@router.get( + "/energy-load-forecast-providers/types/{adapter_type}/external-services", + response_model=Optional[ExternalServiceAdapter], +) +async def get_energy_load_forecast_provider_type_external_service_types( + adapter_type: EnergyLoadForecastProviderAdapter, + config_service: Annotated[ConfigurationServiceInterface, Depends(get_config_service)], +) -> Optional[ExternalServiceAdapter]: + """Get the compatible external service type for a specific energy load forecast provider type.""" + try: + return config_service.get_energy_load_forecast_provider_external_service_adapter(adapter_type) + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) from e + + +@router.get( + "/energy-load-forecast-providers/types/{adapter_type}/config-schema", + response_model=Dict[str, Any], +) +async def get_energy_load_forecast_provider_config_schema( + adapter_type: EnergyLoadForecastProviderAdapter, + config_service: Annotated[ConfigurationServiceInterface, Depends(get_config_service)], +) -> Dict[str, Any]: + """Get the configuration schema for a specific energy load forecast provider type.""" + try: + try: + provider_adapter = EnergyLoadForecastProviderAdapter(adapter_type.value) + except ValueError as e: + raise ValueError(f"Invalid energy load forecast provider adapter type: {adapter_type}") from e + + # Get the corresponding configuration class for the adapter type + provider_config_type: Optional[type[EnergyLoadForecastProviderConfig]] = ( + 
ENERGY_LOAD_FORECAST_PROVIDER_CONFIG_TYPE_MAP.get(provider_adapter) + ) + + if provider_config_type is None: + raise EnergyLoadForecastProviderConfigurationError( + f"No configuration class found for adapter type {adapter_type}" + ) + + # Get the corresponding schema class + schema_class = ENERGY_LOAD_FORECAST_PROVIDER_CONFIG_SCHEMA_MAP.get(provider_config_type) + if schema_class is None: + raise EnergyLoadForecastProviderConfigurationError( + f"No schema found for configuration class {provider_config_type}" + ) + + # Return the JSON schema + return schema_class.model_json_schema() + except ValueError as e: + raise HTTPException(status_code=400, detail=str(e)) from e + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) from e + + +@router.get("/energy-load-forecast-providers/{provider_id}", response_model=EnergyLoadForecastProviderSchema) +async def get_energy_load_forecast_provider( + provider_id: EntityId, + config_service: Annotated[ConfigurationServiceInterface, Depends(get_config_service)], +) -> EnergyLoadForecastProviderSchema: + """Get details of a specific energy load forecast provider.""" + try: + provider = config_service.get_energy_load_forecast_provider(provider_id) + if provider is None: + raise EnergyLoadForecastProviderNotFoundError( + f"Energy Load Forecast Provider with ID {provider_id} not found" + ) + return EnergyLoadForecastProviderSchema.from_model(provider) + except EnergyLoadForecastProviderNotFoundError as e: + raise HTTPException(status_code=404, detail=str(e)) from e + except ValueError as e: + raise HTTPException(status_code=400, detail=str(e)) from e + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) from e + + +@router.put("/energy-load-forecast-providers/{provider_id}", response_model=EnergyLoadForecastProviderSchema) +async def update_energy_load_forecast_provider( + provider_id: EntityId, + provider_update: EnergyLoadForecastProviderUpdateSchema, + config_service: 
Annotated[ConfigurationServiceInterface, Depends(get_config_service)], +) -> EnergyLoadForecastProviderSchema: + """Update an existing energy load forecast provider.""" + try: + existing = config_service.get_energy_load_forecast_provider(provider_id) + if existing is None: + raise EnergyLoadForecastProviderNotFoundError( + f"Energy Load Forecast Provider with ID {provider_id} not found" + ) + existing.name = provider_update.name or existing.name + if provider_update.config is not None and existing.adapter_type: + config_type = ENERGY_LOAD_FORECAST_PROVIDER_CONFIG_TYPE_MAP.get(existing.adapter_type) + if config_type: + existing.config = cast(EnergyLoadForecastProviderConfig, config_type.from_dict(provider_update.config)) + if provider_update.external_service_id is not None: + existing.external_service_id = EntityId(uuid.UUID(provider_update.external_service_id)) + updated = config_service.update_energy_load_forecast_provider(existing) + return EnergyLoadForecastProviderSchema.from_model(updated) + except EnergyLoadForecastProviderNotFoundError as e: + raise HTTPException(status_code=404, detail=str(e)) from e + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) from e + + +@router.delete("/energy-load-forecast-providers/{provider_id}", response_model=EnergyLoadForecastProviderSchema) +async def delete_energy_load_forecast_provider( + provider_id: EntityId, + config_service: Annotated[ConfigurationServiceInterface, Depends(get_config_service)], +) -> EnergyLoadForecastProviderSchema: + """Remove an energy load forecast provider.""" + try: + removed = config_service.remove_energy_load_forecast_provider(provider_id) + return EnergyLoadForecastProviderSchema.from_model(removed) + except EnergyLoadForecastProviderNotFoundError as e: + raise HTTPException(status_code=404, detail=str(e)) from e + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) from e + + +# --- Energy Load History Provider endpoints --- + + 
+@router.get("/energy-load-history-providers", response_model=List[EnergyLoadHistoryProviderSchema]) +async def get_energy_load_history_providers_list( + config_service: Annotated[ConfigurationServiceInterface, Depends(get_config_service)], +) -> List[EnergyLoadHistoryProviderSchema]: + """Get a list of all energy load history providers.""" + try: + providers = config_service.list_energy_load_history_providers() + return [EnergyLoadHistoryProviderSchema.from_model(p) for p in providers] + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) from e + + +@router.post("/energy-load-history-providers", response_model=EnergyLoadHistoryProviderSchema) +async def add_energy_load_history_provider( + provider_data: EnergyLoadHistoryProviderCreateSchema, + config_service: Annotated[ConfigurationServiceInterface, Depends(get_config_service)], +) -> EnergyLoadHistoryProviderSchema: + """Add a new energy load history provider.""" + try: + provider_to_add: EnergyLoadHistoryProvider = provider_data.to_model() + added = config_service.add_energy_load_history_provider(provider_to_add) + return EnergyLoadHistoryProviderSchema.from_model(added) + except EnergyLoadHistoryProviderAlreadyExistsError as e: + raise HTTPException(status_code=400, detail=str(e)) from e + except EnergyLoadHistoryProviderConfigurationError as e: + raise HTTPException(status_code=400, detail=str(e)) from e + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) from e + + +@router.get("/energy-load-history-providers/types", response_model=List[EnergyLoadHistoryProviderAdapter]) +async def get_energy_load_history_provider_types() -> List[EnergyLoadHistoryProviderAdapter]: + """Get a list of available energy load history provider types.""" + try: + return [EnergyLoadHistoryProviderAdapter(adapter.value) for adapter in EnergyLoadHistoryProviderAdapter] + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) from e + + +@router.get( + 
"/energy-load-history-providers/types/{adapter_type}/external-services", + response_model=Optional[ExternalServiceAdapter], +) +async def get_energy_load_history_provider_type_external_service_types( + adapter_type: EnergyLoadHistoryProviderAdapter, + config_service: Annotated[ConfigurationServiceInterface, Depends(get_config_service)], +) -> Optional[ExternalServiceAdapter]: + """Get the compatible external service type for a specific energy load history provider type.""" + try: + return config_service.get_energy_load_history_provider_external_service_adapter(adapter_type) + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) from e + + +@router.get( + "/energy-load-history-providers/types/{adapter_type}/config-schema", + response_model=Dict[str, Any], +) +async def get_energy_load_history_provider_config_schema( + adapter_type: EnergyLoadHistoryProviderAdapter, + config_service: Annotated[ConfigurationServiceInterface, Depends(get_config_service)], +) -> Dict[str, Any]: + """Get the configuration schema for a specific energy load history provider type.""" + try: + try: + provider_adapter = EnergyLoadHistoryProviderAdapter(adapter_type.value) + except ValueError as e: + raise ValueError(f"Invalid energy load history provider adapter type: {adapter_type}") from e + + provider_config_type: Optional[type[EnergyLoadHistoryProviderConfig]] = ( + ENERGY_LOAD_HISTORY_PROVIDER_CONFIG_TYPE_MAP.get(provider_adapter) + ) + + if provider_config_type is None: + return {} # Some adapters (e.g. 
DUMMY) have no configuration + + schema_class = ENERGY_LOAD_HISTORY_PROVIDER_CONFIG_SCHEMA_MAP.get(provider_config_type) + if schema_class is None: + raise EnergyLoadHistoryProviderConfigurationError( + f"No schema found for configuration class {provider_config_type}" + ) + + return schema_class.model_json_schema() + except ValueError as e: + raise HTTPException(status_code=400, detail=str(e)) from e + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) from e + + +@router.get("/energy-load-history-providers/{provider_id}", response_model=EnergyLoadHistoryProviderSchema) +async def get_energy_load_history_provider( + provider_id: EntityId, + config_service: Annotated[ConfigurationServiceInterface, Depends(get_config_service)], +) -> EnergyLoadHistoryProviderSchema: + """Get details of a specific energy load history provider.""" + try: + provider = config_service.get_energy_load_history_provider(provider_id) + if provider is None: + raise EnergyLoadHistoryProviderNotFoundError( + f"Energy Load History Provider with ID {provider_id} not found" + ) + return EnergyLoadHistoryProviderSchema.from_model(provider) + except EnergyLoadHistoryProviderNotFoundError as e: + raise HTTPException(status_code=404, detail=str(e)) from e + except ValueError as e: + raise HTTPException(status_code=400, detail=str(e)) from e + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) from e + + +@router.put("/energy-load-history-providers/{provider_id}", response_model=EnergyLoadHistoryProviderSchema) +async def update_energy_load_history_provider( + provider_id: EntityId, + provider_update: EnergyLoadHistoryProviderUpdateSchema, + config_service: Annotated[ConfigurationServiceInterface, Depends(get_config_service)], +) -> EnergyLoadHistoryProviderSchema: + """Update an existing energy load history provider.""" + try: + existing = config_service.get_energy_load_history_provider(provider_id) + if existing is None: + raise 
EnergyLoadHistoryProviderNotFoundError( + f"Energy Load History Provider with ID {provider_id} not found" + ) + existing.name = provider_update.name or existing.name + if provider_update.config is not None and existing.adapter_type: + config_type = ENERGY_LOAD_HISTORY_PROVIDER_CONFIG_TYPE_MAP.get(existing.adapter_type) + if config_type: + existing.config = cast(EnergyLoadHistoryProviderConfig, config_type.from_dict(provider_update.config)) + if provider_update.external_service_id is not None: + existing.external_service_id = EntityId(uuid.UUID(provider_update.external_service_id)) + updated = config_service.update_energy_load_history_provider(existing) + return EnergyLoadHistoryProviderSchema.from_model(updated) + except EnergyLoadHistoryProviderNotFoundError as e: + raise HTTPException(status_code=404, detail=str(e)) from e + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) from e + + +@router.delete("/energy-load-history-providers/{provider_id}", response_model=EnergyLoadHistoryProviderSchema) +async def delete_energy_load_history_provider( + provider_id: EntityId, + config_service: Annotated[ConfigurationServiceInterface, Depends(get_config_service)], +) -> EnergyLoadHistoryProviderSchema: + """Remove an energy load history provider.""" + try: + removed = config_service.remove_energy_load_history_provider(provider_id) + return EnergyLoadHistoryProviderSchema.from_model(removed) + except EnergyLoadHistoryProviderNotFoundError as e: + raise HTTPException(status_code=404, detail=str(e)) from e + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) from e + + +# --- Device History Data endpoints --- + + +@router.get( + "/home-loads-profiles/{profile_id}/devices/{device_id}/history", + response_model=List[HomeLoadPowerPointSchema], +) +async def get_device_history( + profile_id: EntityId, + device_id: EntityId, + config_service: Annotated[ConfigurationServiceInterface, Depends(get_config_service)], + 
history_service: Annotated[HomeLoadHistoryServiceInterface, Depends(get_home_load_history_service)], + start: datetime = Query(..., description="Start of the time window (ISO 8601)"), + end: datetime = Query(..., description="End of the time window (ISO 8601)"), +) -> List[HomeLoadPowerPointSchema]: + """Get historical power points for a specific device within a time window.""" + try: + # Validate that profile and device exist + profile = config_service.get_home_loads_profile(profile_id) + if profile is None: + raise HomeLoadsProfileNotFoundError(f"Home Loads Profile with ID {profile_id} not found") + + device = next((d for d in profile.devices if d.id == device_id), None) + if device is None: + raise HomeLoadsProfileDeviceNotFoundError( + f"Load Device with ID {device_id} not found in Home Loads Profile {profile_id}" + ) + + points = history_service.get_device_history(device_id, Timestamp(start), Timestamp(end)) + return [HomeLoadPowerPointSchema.from_model(p) for p in points] + except HomeLoadsProfileNotFoundError as e: + raise HTTPException(status_code=404, detail=str(e)) from e + except HomeLoadsProfileDeviceNotFoundError as e: + raise HTTPException(status_code=404, detail=str(e)) from e + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) from e + + +@router.get( + "/home-loads-profiles/{profile_id}/devices/{device_id}/forecast", + response_model=LoadEnergyConsumptionSchema, +) +async def get_device_forecast( + profile_id: EntityId, + device_id: EntityId, + config_service: Annotated[ConfigurationServiceInterface, Depends(get_config_service)], + adapter_service: Annotated[AdapterServiceInterface, Depends(get_adapter_service)], + history_service: Annotated[HomeLoadHistoryServiceInterface, Depends(get_home_load_history_service)], + hours_ahead: int = Query(default=3, ge=1, le=48, description="Forecast horizon in hours"), + history_hours: int = Query(default=72, ge=1, le=720, description="Hours of history to feed the model"), +) -> 
LoadEnergyConsumptionSchema: + """Get energy consumption forecast for a specific device.""" + try: + forecast_provider = None # keep bound so the EnergyLoadForecastProviderError handler below is safe + profile = config_service.get_home_loads_profile(profile_id) + if profile is None: + raise HomeLoadsProfileNotFoundError(f"Home Loads Profile with ID {profile_id} not found") + + device = next((d for d in profile.devices if d.id == device_id), None) + if device is None: + raise HomeLoadsProfileDeviceNotFoundError( + f"Load Device with ID {device_id} not found in Home Loads Profile {profile_id}" + ) + + if not device.energy_load_forecast_provider_id: + raise HTTPException( + status_code=400, + detail=f"Device '{device.name}' has no forecast provider configured.", + ) + + forecast_provider = adapter_service.get_home_load_forecast_provider(device.energy_load_forecast_provider_id) + if forecast_provider is None: + raise HTTPException( + status_code=500, + detail=f"Could not initialize forecast provider for device '{device.name}'.", + ) + + now = Timestamp(datetime.now(timezone.utc)) + history_start = Timestamp(now - timedelta(hours=history_hours)) + power_points = history_service.get_device_history(device_id, history_start, now) + + if not power_points: + raise HTTPException( + status_code=400, + detail=f"No history data available for device '{device.name}'. 
Collect history first.", + ) + + intervals = group_power_points_into_intervals(power_points) + consumption = LoadEnergyConsumption(timestamp=now, intervals=intervals) + + forecast = forecast_provider.get_consumption_forecast(consumption, hours_ahead=hours_ahead) + if forecast is None: + raise HTTPException( + status_code=500, + detail=f"Forecast provider returned no data for device '{device.name}'.", + ) + + return LoadEnergyConsumptionSchema.from_model(forecast) + except HomeLoadsProfileNotFoundError as e: + raise HTTPException(status_code=404, detail=str(e)) from e + except HomeLoadsProfileDeviceNotFoundError as e: + raise HTTPException(status_code=404, detail=str(e)) from e + except EnergyLoadForecastProviderError as e: + min_hours = getattr(forecast_provider, "min_required_history_hours", None) + detail = str(e) + if min_hours: + detail += f" (minimum required: {min_hours} hours)" + raise HTTPException(status_code=400, detail=detail) from e + except HTTPException: + raise + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) from e + + +@router.post( + "/home-loads-profiles/{profile_id}/devices/{device_id}/history/collect", + response_model=Dict[str, str], +) +async def collect_device_history( + profile_id: EntityId, + device_id: EntityId, + config_service: Annotated[ConfigurationServiceInterface, Depends(get_config_service)], + history_service: Annotated[HomeLoadHistoryServiceInterface, Depends(get_home_load_history_service)], + lookback_hours: int = Query(default=24, ge=1, le=720, description="Hours of history to fetch on first collection"), +) -> Dict[str, str]: + """Fetch power points from the history provider and store them in the database.""" + try: + profile = config_service.get_home_loads_profile(profile_id) + if profile is None: + raise HomeLoadsProfileNotFoundError(f"Home Loads Profile with ID {profile_id} not found") + + device = next((d for d in profile.devices if d.id == device_id), None) + if device is None: + raise 
HomeLoadsProfileDeviceNotFoundError( + f"Load Device with ID {device_id} not found in Home Loads Profile {profile_id}" + ) + + if not device.energy_load_history_provider_id: + raise HTTPException( + status_code=400, + detail=f"Device '{device.name}' has no history provider configured.", + ) + + await history_service.collect_devices([device_id], lookback_hours=lookback_hours) + return {"status": "completed", "detail": f"History collection completed for device '{device.name}'."} + except HomeLoadsProfileNotFoundError as e: + raise HTTPException(status_code=404, detail=str(e)) from e + except HomeLoadsProfileDeviceNotFoundError as e: + raise HTTPException(status_code=404, detail=str(e)) from e + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) from e + + +@router.delete( + "/home-loads-profiles/{profile_id}/devices/{device_id}/history", + response_model=Dict[str, str], +) +async def delete_device_history( + profile_id: EntityId, + device_id: EntityId, + config_service: Annotated[ConfigurationServiceInterface, Depends(get_config_service)], + history_service: Annotated[HomeLoadHistoryServiceInterface, Depends(get_home_load_history_service)], +) -> Dict[str, str]: + """Delete all stored power points for a specific device.""" + try: + profile = config_service.get_home_loads_profile(profile_id) + if profile is None: + raise HomeLoadsProfileNotFoundError(f"Home Loads Profile with ID {profile_id} not found") + + device = next((d for d in profile.devices if d.id == device_id), None) + if device is None: + raise HomeLoadsProfileDeviceNotFoundError( + f"Load Device with ID {device_id} not found in Home Loads Profile {profile_id}" + ) + + removed = history_service.clear_device_history(device_id) + return { + "status": "completed", + "detail": f"Deleted {removed} power points for device '{device.name}'.", + } + except HomeLoadsProfileNotFoundError as e: + raise HTTPException(status_code=404, detail=str(e)) from e + except 
HomeLoadsProfileDeviceNotFoundError as e: + raise HTTPException(status_code=404, detail=str(e)) from e + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) from e + + +# --- History Collection endpoints --- + + +@router.post("/history/collect", response_model=Dict[str, str]) +async def trigger_history_collection( + history_service: Annotated[HomeLoadHistoryServiceInterface, Depends(get_home_load_history_service)], + lookback_hours: int = Query(default=24, ge=1, le=720, description="Hours of history to fetch on first collection"), +) -> Dict[str, str]: + """Manually trigger power-point collection for all enabled devices.""" + try: + await history_service.collect_all(lookback_hours=lookback_hours) + return {"status": "completed", "detail": "History collection completed for all eligible devices."} + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) from e + + +@router.post("/history/collect/devices", response_model=Dict[str, str]) +async def trigger_history_collection_for_devices( + device_ids: List[str], + history_service: Annotated[HomeLoadHistoryServiceInterface, Depends(get_home_load_history_service)], + lookback_hours: int = Query(default=24, ge=1, le=720, description="Hours of history to fetch on first collection"), +) -> Dict[str, str]: + """Manually trigger power-point collection for specific devices.""" + try: + parsed_ids = [EntityId(uuid.UUID(did)) for did in device_ids] + await history_service.collect_devices(parsed_ids, lookback_hours=lookback_hours) + return { + "status": "completed", + "detail": f"History collection completed for {len(parsed_ids)} device(s).", + } + except ValueError as e: + raise HTTPException(status_code=400, detail=f"Invalid device ID: {e}") from e + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) from e + + +# --- Training endpoints --- + + +@router.post("/training/trigger", response_model=Dict[str, str]) +async def trigger_training_all( + 
training_service: Annotated[LoadForecastTrainingServiceInterface, Depends(get_load_forecast_training_service)], + weeks_lookback: int = Query(default=8, ge=1, le=52, description="Weeks of history to use"), +) -> Dict[str, str]: + """Trigger ML model training for all enabled devices.""" + try: + await training_service.train_all(weeks_lookback=weeks_lookback) + return {"status": "completed", "detail": "Training completed for all eligible devices."} + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) from e + + +@router.post( + "/home-loads-profiles/{profile_id}/devices/{device_id}/training/trigger", + response_model=Dict[str, str], +) +async def trigger_training_device( + profile_id: EntityId, + device_id: EntityId, + config_service: Annotated[ConfigurationServiceInterface, Depends(get_config_service)], + training_service: Annotated[LoadForecastTrainingServiceInterface, Depends(get_load_forecast_training_service)], + weeks_lookback: int = Query(default=8, ge=1, le=52, description="Weeks of history to use"), +) -> Dict[str, str]: + """Trigger ML model training for a specific device.""" + try: + # Validate that profile and device exist + profile = config_service.get_home_loads_profile(profile_id) + if profile is None: + raise HomeLoadsProfileNotFoundError(f"Home Loads Profile with ID {profile_id} not found") + + device = next((d for d in profile.devices if d.id == device_id), None) + if device is None: + raise HomeLoadsProfileDeviceNotFoundError( + f"Load Device with ID {device_id} not found in Home Loads Profile {profile_id}" + ) + + await training_service.train_device(device_id, weeks_lookback=weeks_lookback) + return {"status": "completed", "detail": f"Training completed for device '{device.name}'."} + except HomeLoadsProfileNotFoundError as e: + raise HTTPException(status_code=404, detail=str(e)) from e + except HomeLoadsProfileDeviceNotFoundError as e: + raise HTTPException(status_code=404, detail=str(e)) from e + except Exception as e: 
+ raise HTTPException(status_code=500, detail=str(e)) from e + + +# --- Training Models endpoints --- + + +@router.get("/training/models", response_model=List[LoadConsumptionModelSchema]) +async def get_training_models( + training_service: Annotated[LoadForecastTrainingServiceInterface, Depends(get_load_forecast_training_service)], + device_id: Optional[str] = Query(default=None, description="Filter by device UUID"), +) -> List[LoadConsumptionModelSchema]: + """List trained ML models, optionally filtered by device.""" + try: + filter_device_id = EntityId(uuid.UUID(device_id)) if device_id else None + models = training_service.get_models(device_id=filter_device_id) + return [LoadConsumptionModelSchema.from_model(m) for m in models] + except ValueError as e: + raise HTTPException(status_code=400, detail=f"Invalid device_id: {e}") from e + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) from e + + +@router.delete("/training/models/{model_id}", status_code=204) +async def delete_training_model( + model_id: str, + training_service: Annotated[LoadForecastTrainingServiceInterface, Depends(get_load_forecast_training_service)], +) -> None: + """Delete a trained ML model by ID.""" + try: + entity_id = EntityId(uuid.UUID(model_id)) + training_service.delete_model(entity_id) + except ValueError as e: + raise HTTPException(status_code=400, detail=f"Invalid model_id: {e}") from e + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) from e diff --git a/edge_mining/adapters/domain/home_load/forecast_providers/__init__.py b/edge_mining/adapters/domain/home_load/forecast_providers/__init__.py new file mode 100644 index 0000000..da0cc0b --- /dev/null +++ b/edge_mining/adapters/domain/home_load/forecast_providers/__init__.py @@ -0,0 +1 @@ +"""Collection of home load forecast provider adapters.""" diff --git a/edge_mining/adapters/domain/home_load/forecast_providers/dummy.py 
b/edge_mining/adapters/domain/home_load/forecast_providers/dummy.py new file mode 100644 index 0000000..2843fe5 --- /dev/null +++ b/edge_mining/adapters/domain/home_load/forecast_providers/dummy.py @@ -0,0 +1,101 @@ +""" +Dummy adapter (Implementation of Port) that simulates +the home loads forecast for Edge Mining Application. +""" + +import random +from datetime import datetime, timedelta, timezone +from typing import List, Optional + +from edge_mining.domain.common import Timestamp, WattHours, Watts +from edge_mining.domain.home_load.common import EnergyLoadForecastProviderAdapter +from edge_mining.domain.home_load.exceptions import EnergyLoadForecastProviderError +from edge_mining.domain.home_load.ports import EnergyLoadForecastProviderPort +from edge_mining.domain.home_load.value_objects import HomeLoadEnergyInterval, HomeLoadPowerPoint, LoadEnergyConsumption +from edge_mining.shared.adapter_configs.home_load import EnergyLoadForecastProviderDummyConfig +from edge_mining.shared.external_services.ports import ExternalServicePort +from edge_mining.shared.interfaces.config import Configuration +from edge_mining.shared.interfaces.factories import EnergyLoadForecastAdapterFactory +from edge_mining.shared.logging.port import LoggerPort + + +class DummyEnergyLoadForecastProviderFactory(EnergyLoadForecastAdapterFactory): + """Factory for creating a DummyEnergyLoadForecastProvider instance.""" + + def create( + self, + config: Optional[Configuration], + logger: Optional[LoggerPort], + external_service: Optional[ExternalServicePort], + ) -> "DummyEnergyLoadForecastProvider": + if config is not None and not isinstance(config, EnergyLoadForecastProviderDummyConfig): + raise EnergyLoadForecastProviderError( + "Invalid configuration type for Dummy energy load forecast provider. " + "Expected EnergyLoadForecastProviderDummyConfig." 
+ ) + + load_power_max = 500.0 + if isinstance(config, EnergyLoadForecastProviderDummyConfig): + load_power_max = config.load_power_max + + return DummyEnergyLoadForecastProvider( + load_power_max=load_power_max, + logger=logger, + ) + + +class DummyEnergyLoadForecastProvider(EnergyLoadForecastProviderPort): + """Generates a very basic fake energy load forecast. + + Replays the average power of the last observed interval when history is + available; otherwise emits a random load bounded by ``load_power_max``. + Useful as a placeholder until an ML/DL forecaster is wired in. + """ + + def __init__( + self, + load_power_max: float = 500.0, + logger: Optional[LoggerPort] = None, + ): + super().__init__(forecast_provider_type=EnergyLoadForecastProviderAdapter.DUMMY) + self._logger = logger + self.load_power_max = load_power_max + + def get_consumption_forecast( + self, consumption_history: LoadEnergyConsumption, hours_ahead: int = 3 + ) -> Optional[LoadEnergyConsumption]: + """Produce a naive forecast of hourly consumption over ``hours_ahead``.""" + if hours_ahead <= 0: + return None + + now = Timestamp(datetime.now(timezone.utc)) + + if consumption_history.intervals: + # Simple baseline: replay the average of the last observed hour. 
+ baseline_power = consumption_history.intervals[-1].avg_power + else: + baseline_power = Watts(random.uniform(200.0, self.load_power_max)) + + intervals: List[HomeLoadEnergyInterval] = [] + for i in range(hours_ahead): + start = now + timedelta(hours=i) + end = start + timedelta(hours=1) + point = HomeLoadPowerPoint(timestamp=start, power=baseline_power) + intervals.append( + HomeLoadEnergyInterval( + start=start, + end=end, + power_points=[point], + energy=WattHours(float(baseline_power)), + ) + ) + + forecast = LoadEnergyConsumption(timestamp=now, intervals=intervals) + + if self._logger: + self._logger.debug( + f"DummyEnergyLoadForecastProvider: baseline {baseline_power:.0f}W, " + f"{hours_ahead}h ahead, avg_power={forecast.avg_power:.0f}W" + ) + + return forecast diff --git a/edge_mining/adapters/domain/home_load/forecast_providers/features.py b/edge_mining/adapters/domain/home_load/forecast_providers/features.py new file mode 100644 index 0000000..44476ff --- /dev/null +++ b/edge_mining/adapters/domain/home_load/forecast_providers/features.py @@ -0,0 +1,140 @@ +"""Feature engineering utilities for ML-based home load forecast providers. + +Converts LoadEnergyConsumption interval data into structured feature arrays +suitable for scikit-learn / statsmodels / XGBoost models. +""" + +from datetime import datetime, timedelta +from typing import List, Optional, Tuple + +from edge_mining.domain.home_load.value_objects import LoadEnergyConsumption + + +def intervals_to_hourly_series( + consumption: LoadEnergyConsumption, +) -> List[Tuple[datetime, float]]: + """Convert intervals to a sorted list of (timestamp, avg_power) pairs. + + Missing hours are NOT filled — that is the caller's responsibility + (e.g. via ``fill_missing_hours``). 
+ """ + pairs: List[Tuple[datetime, float]] = [] + for interval in consumption.intervals: + pairs.append((interval.start, float(interval.avg_power))) + pairs.sort(key=lambda x: x[0]) + return pairs + + +def fill_missing_hours( + series: List[Tuple[datetime, float]], + start: Optional[datetime] = None, + end: Optional[datetime] = None, + fill_value: float = 0.0, +) -> List[Tuple[datetime, float]]: + """Ensure contiguous hourly coverage by inserting fill_value for missing slots. + + If *start*/*end* are not provided they default to the min/max of the + existing series. + """ + if not series: + return [] + + existing = {ts.replace(minute=0, second=0, microsecond=0): power for ts, power in series} + first = start or min(existing) + last = end or max(existing) + + result: List[Tuple[datetime, float]] = [] + current = first.replace(minute=0, second=0, microsecond=0) + last_rounded = last.replace(minute=0, second=0, microsecond=0) + while current <= last_rounded: + result.append((current, existing.get(current, fill_value))) + current += timedelta(hours=1) + return result + + +def build_calendar_features(timestamps: List[datetime]) -> List[List[float]]: + """Build calendar feature vectors for a list of timestamps. + + Each row contains: + [hour_of_day, day_of_week, is_weekend, month] + + All values are numeric (float) for direct use in sklearn / XGBoost. + """ + features: List[List[float]] = [] + for ts in timestamps: + hour = float(ts.hour) + dow = float(ts.weekday()) # 0=Mon … 6=Sun + is_weekend = 1.0 if dow >= 5 else 0.0 + month = float(ts.month) + features.append([hour, dow, is_weekend, month]) + return features + + +def build_lag_features( + power_values: List[float], + lags: Optional[List[int]] = None, +) -> List[List[Optional[float]]]: + """Build lag feature vectors from a power time series. + + Default lags: 1h, 2h, 3h, 24h (same hour yesterday), 168h (same hour last week). + + Returns a list of rows; each row has one value per lag. 
+ Positions where the lag is not available are filled with ``None``. + """ + if lags is None: + lags = [1, 2, 3, 24, 168] + + rows: List[List[Optional[float]]] = [] + for i in range(len(power_values)): + row: List[Optional[float]] = [] + for lag in lags: + idx = i - lag + row.append(power_values[idx] if idx >= 0 else None) + rows.append(row) + return rows + + +def prepare_supervised_dataset( + consumption: LoadEnergyConsumption, + hours_ahead: int = 3, + lags: Optional[List[int]] = None, +) -> Tuple[List[List[float]], List[float]]: + """Build X (features) and y (targets) from historical consumption. + + Each sample is one historical hour; features are calendar + lag; + target is the avg_power ``hours_ahead`` hours later. + + Rows where lags or target are unavailable are dropped. + + Returns (X, y) where X is a list of feature rows and y is a list of targets. + """ + if lags is None: + lags = [1, 2, 3, 24, 168] + + series = intervals_to_hourly_series(consumption) + series = fill_missing_hours(series) + + if not series: + return [], [] + + timestamps = [ts for ts, _ in series] + powers = [p for _, p in series] + calendar = build_calendar_features(timestamps) + lag_rows = build_lag_features(powers, lags=lags) + + max_lag = max(lags) if lags else 0 + + X: List[List[float]] = [] + y: List[float] = [] + + for i in range(max_lag, len(series) - hours_ahead): + lag_row = lag_rows[i] + # skip if any lag is None (should not happen past max_lag, but be safe) + if any(v is None for v in lag_row): + continue + features = calendar[i] + [float(v) for v in lag_row] # type: ignore[arg-type] + target = powers[i + hours_ahead] + X.append(features) + y.append(target) + + return X, y diff --git a/edge_mining/adapters/domain/home_load/forecast_providers/naive_last_hour.py b/edge_mining/adapters/domain/home_load/forecast_providers/naive_last_hour.py new file mode 100644 index 0000000..1461f72 --- /dev/null +++ b/edge_mining/adapters/domain/home_load/forecast_providers/naive_last_hour.py 
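The calendar + lag pipeline in `features.py` above can be sketched standalone. This is a hedged re-statement that mirrors the diff's logic but deliberately omits the `LoadEnergyConsumption` domain types, working directly on plain `(timestamp, power)` pairs; the function names are reused only for illustration.

```python
from datetime import datetime, timedelta

# Standalone sketch of the feature helpers above (domain types omitted):
# hourly (timestamp, avg_power) pairs in, calendar + lag feature rows out.

def build_calendar_features(timestamps):
    rows = []
    for ts in timestamps:
        dow = float(ts.weekday())  # 0=Mon … 6=Sun
        rows.append([float(ts.hour), dow, 1.0 if dow >= 5 else 0.0, float(ts.month)])
    return rows

def build_lag_features(power_values, lags=(1, 2, 3, 24, 168)):
    # One row per sample; None where the lag reaches before the series start.
    rows = []
    for i in range(len(power_values)):
        rows.append([power_values[i - lag] if i - lag >= 0 else None for lag in lags])
    return rows

# Two weeks of synthetic hourly data with an evening bump (illustrative only)
start = datetime(2024, 1, 1)
series = [(start + timedelta(hours=h), 300.0 + 100.0 * (h % 24 >= 18)) for h in range(14 * 24)]
timestamps = [ts for ts, _ in series]
powers = [p for _, p in series]

calendar = build_calendar_features(timestamps)
lag_rows = build_lag_features(powers)

# Past the largest lag (168h) every lag slot is populated, so a supervised
# sample is calendar + lag features, target = power `hours_ahead` hours later.
hours_ahead = 3
X = [calendar[i] + lag_rows[i] for i in range(168, len(series) - hours_ahead)]
y = [powers[i + hours_ahead] for i in range(168, len(series) - hours_ahead)]
print(len(X), len(X[0]), y[0])  # → 165 9 300.0
```

The same shape (`X` rows of 4 calendar + 5 lag features) is what `prepare_supervised_dataset()` in the diff produces for the sklearn/XGBoost trainers.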
@@ -0,0 +1,108 @@ +"""NaiveLastHour forecast provider for energy load consumption.""" + +from datetime import datetime, timedelta, timezone +from typing import List, Optional + +from edge_mining.domain.common import Timestamp, WattHours, Watts +from edge_mining.domain.home_load.common import EnergyLoadForecastProviderAdapter +from edge_mining.domain.home_load.exceptions import EnergyLoadForecastProviderError +from edge_mining.domain.home_load.ports import EnergyLoadForecastProviderPort +from edge_mining.domain.home_load.value_objects import ( + HomeLoadEnergyInterval, + HomeLoadPowerPoint, + LoadEnergyConsumption, +) +from edge_mining.shared.adapter_configs.home_load import EnergyLoadForecastProviderNaiveLastHourConfig +from edge_mining.shared.external_services.ports import ExternalServicePort +from edge_mining.shared.interfaces.config import Configuration +from edge_mining.shared.interfaces.factories import EnergyLoadForecastAdapterFactory +from edge_mining.shared.logging.port import LoggerPort + + +class NaiveLastHourForecastProviderFactory(EnergyLoadForecastAdapterFactory): + """Factory for creating a NaiveLastHourForecastProvider instance.""" + + def create( + self, + config: Optional[Configuration], + logger: Optional[LoggerPort], + external_service: Optional[ExternalServicePort], + ) -> "NaiveLastHourForecastProvider": + if config is not None and not isinstance(config, EnergyLoadForecastProviderNaiveLastHourConfig): + raise EnergyLoadForecastProviderError( + "Invalid configuration type for NaiveLastHour energy load forecast provider. " + "Expected EnergyLoadForecastProviderNaiveLastHourConfig." 
+ ) + + hours_ahead = 3 + if isinstance(config, EnergyLoadForecastProviderNaiveLastHourConfig): + hours_ahead = config.hours_ahead + + return NaiveLastHourForecastProvider( + hours_ahead=hours_ahead, + logger=logger, + ) + + +class NaiveLastHourForecastProvider(EnergyLoadForecastProviderPort): + """Forecast by repeating the average power of the last hour for N hours ahead. + + This is the simplest non-trivial baseline: it assumes the near future will + look like the recent past. Always available as a fallback even with very + little historical data (only 1 hour needed). + """ + + def __init__(self, hours_ahead: int = 3, logger: Optional[LoggerPort] = None): + super().__init__(forecast_provider_type=EnergyLoadForecastProviderAdapter.NAIVE_LAST_HOUR) + self._hours_ahead = hours_ahead + self._logger = logger + + @property + def min_required_history_hours(self) -> int: # noqa: D102 + return 1 + + def get_consumption_forecast( + self, consumption_history: LoadEnergyConsumption, hours_ahead: int = 3 + ) -> Optional[LoadEnergyConsumption]: + effective_hours = self._hours_ahead or hours_ahead + if effective_hours <= 0: + return None + + now = Timestamp(datetime.now(timezone.utc)) + + # Compute baseline from the last hour of history + last_hour = consumption_history.in_last_hours(1, now=now) + if last_hour.intervals: + baseline_power = last_hour.avg_power + elif consumption_history.intervals: + # Fallback: use overall average if last hour is empty + baseline_power = consumption_history.avg_power + else: + # No history at all — cannot forecast + return None + + if float(baseline_power) <= 0: + baseline_power = Watts(0.0) + + intervals: List[HomeLoadEnergyInterval] = [] + for i in range(effective_hours): + start = Timestamp(now + timedelta(hours=i)) + end = Timestamp(start + timedelta(hours=1)) + point = HomeLoadPowerPoint(timestamp=start, power=baseline_power) + intervals.append( + HomeLoadEnergyInterval( + start=start, + end=end, + power_points=[point], + 
energy=WattHours(float(baseline_power)), + ) + ) + + forecast = LoadEnergyConsumption(timestamp=now, intervals=intervals) + + if self._logger: + self._logger.debug( + f"NaiveLastHourForecastProvider: baseline {baseline_power:.0f}W, " + f"{effective_hours}h ahead, total_energy={forecast.total_energy:.0f}Wh" + ) + return forecast diff --git a/edge_mining/adapters/domain/home_load/forecast_providers/naive_persistence.py b/edge_mining/adapters/domain/home_load/forecast_providers/naive_persistence.py new file mode 100644 index 0000000..3ab742e --- /dev/null +++ b/edge_mining/adapters/domain/home_load/forecast_providers/naive_persistence.py @@ -0,0 +1,129 @@ +"""NaivePersistence forecast provider for energy load consumption. + +Forecasts by repeating the consumption profile from the *same hours of the +previous day*. Unlike ``NaiveLastHour`` (which repeats a single recent average), +this provider preserves the intra-day shape of the load profile — capturing +morning peaks, afternoon dips, etc. + +Inspired by the "naive/persistence" method used in EMHASS. 
+""" + +from datetime import datetime, timedelta, timezone +from typing import List, Optional + +from edge_mining.domain.common import Timestamp, WattHours, Watts +from edge_mining.domain.home_load.common import EnergyLoadForecastProviderAdapter +from edge_mining.domain.home_load.exceptions import EnergyLoadForecastProviderError +from edge_mining.domain.home_load.ports import EnergyLoadForecastProviderPort +from edge_mining.domain.home_load.value_objects import ( + HomeLoadEnergyInterval, + HomeLoadPowerPoint, + LoadEnergyConsumption, +) +from edge_mining.shared.adapter_configs.home_load import EnergyLoadForecastProviderNaivePersistenceConfig +from edge_mining.shared.external_services.ports import ExternalServicePort +from edge_mining.shared.interfaces.config import Configuration +from edge_mining.shared.interfaces.factories import EnergyLoadForecastAdapterFactory +from edge_mining.shared.logging.port import LoggerPort + + +class NaivePersistenceForecastProviderFactory(EnergyLoadForecastAdapterFactory): + """Factory for creating a NaivePersistenceForecastProvider instance.""" + + def create( + self, + config: Optional[Configuration], + logger: Optional[LoggerPort], + external_service: Optional[ExternalServicePort], + ) -> "NaivePersistenceForecastProvider": + if config is not None and not isinstance(config, EnergyLoadForecastProviderNaivePersistenceConfig): + raise EnergyLoadForecastProviderError( + "Invalid configuration type for NaivePersistence energy load forecast provider. " + "Expected EnergyLoadForecastProviderNaivePersistenceConfig." 
+ ) + + hours_ahead = 24 + delta_days = 1 + if isinstance(config, EnergyLoadForecastProviderNaivePersistenceConfig): + hours_ahead = config.hours_ahead + delta_days = config.delta_days + + return NaivePersistenceForecastProvider( + hours_ahead=hours_ahead, + delta_days=delta_days, + logger=logger, + ) + + +class NaivePersistenceForecastProvider(EnergyLoadForecastProviderPort): + """Forecast by repeating the load profile from ``delta_days`` ago. + + For each future hour, this provider looks up the corresponding hour from + ``delta_days`` days in the past and uses that power value. If a specific + hour slot is missing from history, the overall history average is used as + fallback. + """ + + def __init__( + self, + hours_ahead: int = 24, + delta_days: int = 1, + logger: Optional[LoggerPort] = None, + ): + super().__init__(forecast_provider_type=EnergyLoadForecastProviderAdapter.NAIVE_PERSISTENCE) + self._hours_ahead = hours_ahead + self._delta_days = delta_days + self._logger = logger + + @property + def min_required_history_hours(self) -> int: # noqa: D102 + return self._delta_days * 24 + + def get_consumption_forecast( + self, consumption_history: LoadEnergyConsumption, hours_ahead: int = 24 + ) -> Optional[LoadEnergyConsumption]: + effective_hours = self._hours_ahead + if effective_hours <= 0: + return None + + if not consumption_history.intervals: + return None + + now = Timestamp(datetime.now(timezone.utc)) + fallback_power = consumption_history.avg_power + + # Build an hour-of-day → power lookup from the reference day + reference_date = (now - timedelta(days=self._delta_days)).date() + hour_power: dict[int, float] = {} + for interval in consumption_history.intervals: + if interval.start.date() == reference_date: + hour_power[interval.start.hour] = float(interval.avg_power) + + intervals: List[HomeLoadEnergyInterval] = [] + for i in range(effective_hours): + start = Timestamp(now + timedelta(hours=i)) + end = Timestamp(start + timedelta(hours=1)) + target_hour = 
start.hour + + power = Watts(hour_power.get(target_hour, float(fallback_power))) + if float(power) < 0: + power = Watts(0.0) + + point = HomeLoadPowerPoint(timestamp=start, power=power) + intervals.append( + HomeLoadEnergyInterval( + start=start, + end=end, + power_points=[point], + energy=WattHours(float(power)), + ) + ) + + forecast = LoadEnergyConsumption(timestamp=now, intervals=intervals) + + if self._logger: + self._logger.debug( + f"NaivePersistenceForecastProvider: delta_days={self._delta_days}, " + f"{effective_hours}h ahead, total_energy={forecast.total_energy:.0f}Wh" + ) + return forecast diff --git a/edge_mining/adapters/domain/home_load/forecast_providers/seasonal_baseline.py b/edge_mining/adapters/domain/home_load/forecast_providers/seasonal_baseline.py new file mode 100644 index 0000000..54909dc --- /dev/null +++ b/edge_mining/adapters/domain/home_load/forecast_providers/seasonal_baseline.py @@ -0,0 +1,130 @@ +"""SeasonalBaseline forecast provider for energy load consumption.""" + +from collections import defaultdict +from datetime import datetime, timedelta, timezone +from typing import Dict, List, Optional, Tuple + +from edge_mining.domain.common import Timestamp, WattHours, Watts +from edge_mining.domain.home_load.common import EnergyLoadForecastProviderAdapter +from edge_mining.domain.home_load.exceptions import EnergyLoadForecastProviderError +from edge_mining.domain.home_load.ports import EnergyLoadForecastProviderPort +from edge_mining.domain.home_load.value_objects import ( + HomeLoadEnergyInterval, + HomeLoadPowerPoint, + LoadEnergyConsumption, +) +from edge_mining.shared.adapter_configs.home_load import EnergyLoadForecastProviderSeasonalBaselineConfig +from edge_mining.shared.external_services.ports import ExternalServicePort +from edge_mining.shared.interfaces.config import Configuration +from edge_mining.shared.interfaces.factories import EnergyLoadForecastAdapterFactory +from edge_mining.shared.logging.port import LoggerPort + + +class 
SeasonalBaselineForecastProviderFactory(EnergyLoadForecastAdapterFactory): + """Factory for creating a SeasonalBaselineForecastProvider instance.""" + + def create( + self, + config: Optional[Configuration], + logger: Optional[LoggerPort], + external_service: Optional[ExternalServicePort], + ) -> "SeasonalBaselineForecastProvider": + if config is not None and not isinstance(config, EnergyLoadForecastProviderSeasonalBaselineConfig): + raise EnergyLoadForecastProviderError( + "Invalid configuration type for SeasonalBaseline energy load forecast provider. " + "Expected EnergyLoadForecastProviderSeasonalBaselineConfig." + ) + + hours_ahead = 3 + weeks_lookback = 4 + if isinstance(config, EnergyLoadForecastProviderSeasonalBaselineConfig): + hours_ahead = config.hours_ahead + weeks_lookback = config.weeks_lookback + + return SeasonalBaselineForecastProvider( + hours_ahead=hours_ahead, + weeks_lookback=weeks_lookback, + logger=logger, + ) + + +class SeasonalBaselineForecastProvider(EnergyLoadForecastProviderPort): + """Forecast by averaging historical power for each (hour_of_day, day_of_week) slot. + + Uses a configurable look-back window (default 4 weeks) to build a profile + of typical consumption per time slot. For CONTINUOUS and SEASONAL devices + this is a strong baseline. + + If insufficient data exists for a particular slot, falls back to the global + average across all available data. 
+ """ + + def __init__( + self, + hours_ahead: int = 3, + weeks_lookback: int = 4, + logger: Optional[LoggerPort] = None, + ): + super().__init__(forecast_provider_type=EnergyLoadForecastProviderAdapter.SEASONAL_BASELINE) + self._hours_ahead = hours_ahead + self._weeks_lookback = weeks_lookback + self._logger = logger + + def get_consumption_forecast( + self, consumption_history: LoadEnergyConsumption, hours_ahead: int = 3 + ) -> Optional[LoadEnergyConsumption]: + effective_hours = self._hours_ahead or hours_ahead + if effective_hours <= 0: + return None + + if not consumption_history.intervals: + return None + + # Build seasonal profile: (day_of_week, hour_of_day) → list of avg_power + profile: Dict[Tuple[int, int], List[float]] = defaultdict(list) + for interval in consumption_history.intervals: + dow = interval.start.weekday() # 0=Monday + hod = interval.start.hour + power = float(interval.avg_power) + if power > 0: + profile[(dow, hod)].append(power) + + if not profile: + return None + + # Global fallback: average of all observed power values + all_powers = [p for powers in profile.values() for p in powers] + global_avg = sum(all_powers) / len(all_powers) if all_powers else 0.0 + + now = Timestamp(datetime.now(timezone.utc)) + intervals: List[HomeLoadEnergyInterval] = [] + for i in range(effective_hours): + start = Timestamp(now + timedelta(hours=i)) + end = Timestamp(start + timedelta(hours=1)) + + dow = start.weekday() + hod = start.hour + slot_values = profile.get((dow, hod)) + if slot_values: + slot_power = Watts(sum(slot_values) / len(slot_values)) + else: + slot_power = Watts(global_avg) + + point = HomeLoadPowerPoint(timestamp=start, power=slot_power) + intervals.append( + HomeLoadEnergyInterval( + start=start, + end=end, + power_points=[point], + energy=WattHours(float(slot_power)), + ) + ) + + forecast = LoadEnergyConsumption(timestamp=now, intervals=intervals) + + if self._logger: + self._logger.debug( + f"SeasonalBaselineForecastProvider: 
{len(profile)} slots, " + f"{effective_hours}h ahead, total_energy={forecast.total_energy:.0f}Wh" + ) + return forecast diff --git a/edge_mining/adapters/domain/home_load/forecast_providers/skforecast_provider.py b/edge_mining/adapters/domain/home_load/forecast_providers/skforecast_provider.py new file mode 100644 index 0000000..ce83daf --- /dev/null +++ b/edge_mining/adapters/domain/home_load/forecast_providers/skforecast_provider.py @@ -0,0 +1,410 @@ +"""Skforecast ForecasterRecursive provider for energy load consumption. + +Uses ``skforecast.recursive.ForecasterRecursive`` with a configurable +scikit-learn regressor backend. The forecaster handles auto-regressive +multi-step prediction natively: it feeds its own predictions back as input +for subsequent steps. + +Supported sklearn models (selected via ``sklearn_model`` config string): + RandomForestRegressor, GradientBoostingRegressor, ExtraTreesRegressor, + KNeighborsRegressor, Ridge, Lasso, ElasticNet, AdaBoostRegressor, + MLPRegressor, SVR. + +If a pre-trained model exists in ``model_repo`` it is loaded. Otherwise +the provider fits on-the-fly from the supplied consumption history. 
+""" + +import pickle +from datetime import datetime, timedelta, timezone +from typing import Dict, List, Optional, Type + +from edge_mining.domain.common import EntityId, Timestamp, WattHours, Watts +from edge_mining.domain.home_load.common import EnergyLoadForecastProviderAdapter +from edge_mining.domain.home_load.exceptions import EnergyLoadForecastProviderError +from edge_mining.domain.home_load.ports import EnergyLoadForecastProviderPort, LoadConsumptionModelRepository +from edge_mining.domain.home_load.value_objects import ( + HomeLoadEnergyInterval, + HomeLoadPowerPoint, + LoadEnergyConsumption, +) +from edge_mining.shared.adapter_configs.home_load import EnergyLoadForecastProviderSkforecastConfig +from edge_mining.shared.external_services.ports import ExternalServicePort +from edge_mining.shared.interfaces.config import Configuration +from edge_mining.shared.interfaces.factories import EnergyLoadForecastAdapterFactory +from edge_mining.shared.logging.port import LoggerPort + +from .features import fill_missing_hours, intervals_to_hourly_series + +# --------------------------------------------------------------------------- +# Lazy imports — heavy dependencies +# --------------------------------------------------------------------------- +_SKFORECAST_AVAILABLE = False +try: + import pandas as pd + from skforecast.recursive import ForecasterRecursive + + _SKFORECAST_AVAILABLE = True +except ImportError: + pd = None # type: ignore[assignment] + ForecasterRecursive = None # type: ignore[assignment,misc] + +# --------------------------------------------------------------------------- +# Mapping from config string → sklearn class (lazy-resolved) +# --------------------------------------------------------------------------- +_SKLEARN_MODEL_REGISTRY: Dict[str, str] = { + "RandomForestRegressor": "sklearn.ensemble.RandomForestRegressor", + "GradientBoostingRegressor": "sklearn.ensemble.GradientBoostingRegressor", + "ExtraTreesRegressor": 
"sklearn.ensemble.ExtraTreesRegressor", + "AdaBoostRegressor": "sklearn.ensemble.AdaBoostRegressor", + "KNeighborsRegressor": "sklearn.neighbors.KNeighborsRegressor", + "Ridge": "sklearn.linear_model.Ridge", + "Lasso": "sklearn.linear_model.Lasso", + "ElasticNet": "sklearn.linear_model.ElasticNet", + "MLPRegressor": "sklearn.neural_network.MLPRegressor", + "SVR": "sklearn.svm.SVR", +} + + +def _resolve_sklearn_model(name: str) -> object: + """Instantiate an sklearn regressor by its class name.""" + import importlib + + fqn = _SKLEARN_MODEL_REGISTRY.get(name) + if fqn is None: + raise EnergyLoadForecastProviderError( + f"Unsupported sklearn model '{name}'. Available: {list(_SKLEARN_MODEL_REGISTRY.keys())}" + ) + module_path, class_name = fqn.rsplit(".", 1) + module = importlib.import_module(module_path) + cls: Type = getattr(module, class_name) + return cls() + + +class SkforecastForecastProviderFactory(EnergyLoadForecastAdapterFactory): + """Factory for creating a SkforecastForecastProvider instance.""" + + def __init__(self, model_repo: Optional[LoadConsumptionModelRepository] = None) -> None: + self._model_repo = model_repo + + def create( + self, + config: Optional[Configuration], + logger: Optional[LoggerPort], + external_service: Optional[ExternalServicePort], + ) -> "SkforecastForecastProvider": + if config is not None and not isinstance(config, EnergyLoadForecastProviderSkforecastConfig): + raise EnergyLoadForecastProviderError( + "Invalid configuration type for Skforecast energy load forecast provider. " + "Expected EnergyLoadForecastProviderSkforecastConfig." 
+ ) + + hours_ahead = 24 + weeks_lookback = 8 + sklearn_model = "RandomForestRegressor" + num_lags = 72 + if isinstance(config, EnergyLoadForecastProviderSkforecastConfig): + hours_ahead = config.hours_ahead + weeks_lookback = config.weeks_lookback + sklearn_model = config.sklearn_model + num_lags = config.num_lags + + return SkforecastForecastProvider( + hours_ahead=hours_ahead, + weeks_lookback=weeks_lookback, + sklearn_model=sklearn_model, + num_lags=num_lags, + model_repo=self._model_repo, + logger=logger, + ) + + +class SkforecastForecastProvider(EnergyLoadForecastProviderPort): + """Forecast provider using skforecast ForecasterRecursive. + + Uses a configurable sklearn regressor wrapped in ``ForecasterRecursive`` + which automatically manages lag features and recursive multi-step + prediction. + + If a pre-trained model is available in ``model_repo``, it is loaded and + used directly. Otherwise, fits on-the-fly from the provided history. + """ + + def __init__( + self, + hours_ahead: int = 24, + weeks_lookback: int = 8, + sklearn_model: str = "RandomForestRegressor", + num_lags: int = 72, + model_repo: Optional[LoadConsumptionModelRepository] = None, + device_id: Optional[EntityId] = None, + logger: Optional[LoggerPort] = None, + ): + super().__init__(forecast_provider_type=EnergyLoadForecastProviderAdapter.SKFORECAST) + self._hours_ahead = hours_ahead + self._weeks_lookback = weeks_lookback + self._sklearn_model = sklearn_model + self._num_lags = num_lags + self._model_repo = model_repo + self._device_id = device_id + self._logger = logger + + @property + def min_required_history_hours(self) -> int: # noqa: D102 + return self._num_lags + 48 + self._hours_ahead + + def get_consumption_forecast( + self, consumption_history: LoadEnergyConsumption, hours_ahead: int = 24 + ) -> Optional[LoadEnergyConsumption]: + if not _SKFORECAST_AVAILABLE: + if self._logger: + self._logger.warning("skforecast is not installed. 
Skipping Skforecast forecast.") + return None + + effective_hours = self._hours_ahead + if effective_hours <= 0: + return None + + if not consumption_history.intervals: + return None + + # Try saved model first + forecast = self._predict_from_saved_model(effective_hours) + if forecast is not None: + return forecast + + # Fallback: fit on-the-fly + return self._fit_and_predict(consumption_history, effective_hours) + + def _predict_from_saved_model(self, steps: int) -> Optional[LoadEnergyConsumption]: + """Load a pre-trained ForecasterRecursive from model_repo and predict.""" + if self._model_repo is None: + return None + + model_entity = self._model_repo.get_active_model(EnergyLoadForecastProviderAdapter.SKFORECAST, self._device_id) + if model_entity is None or model_entity.model_bytes is None: + return None + + try: + forecaster = pickle.loads(model_entity.model_bytes) # noqa: S301 + predictions = forecaster.predict(steps=steps) + return self._build_forecast(predictions.tolist()) + except Exception as exc: + if self._logger: + self._logger.warning(f"Failed to predict from saved skforecast model: {exc}") + return None + + def _fit_and_predict( + self, consumption_history: LoadEnergyConsumption, steps: int + ) -> Optional[LoadEnergyConsumption]: + """Fit ForecasterRecursive on the fly and predict.""" + series = intervals_to_hourly_series(consumption_history) + series = fill_missing_hours(series) + powers = [p for _, p in series] + + if len(powers) < self._num_lags + steps: + if self._logger: + self._logger.debug( + f"Insufficient data for skforecast: {len(powers)} points, " + f"need {self._num_lags + steps} (lags + steps)." 
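The recursive multi-step scheme that `ForecasterRecursive` supplies — a one-step model whose predictions are fed back as lag inputs for later steps — reduces to a short loop. A sketch with a deliberately trivial "model" (mean of the lag window) standing in for a fitted sklearn regressor:

```python
def recursive_forecast(history: list[float], steps: int, num_lags: int) -> list[float]:
    """Multi-step forecast from a one-step model, feeding predictions back.

    The one-step 'model' here is just the mean of the last `num_lags` values;
    skforecast plugs a fitted regressor into the same loop.
    """
    window = list(history[-num_lags:])
    out: list[float] = []
    for _ in range(steps):
        y_hat = sum(window) / len(window)  # one-step prediction from current lags
        out.append(y_hat)
        window = window[1:] + [y_hat]      # prediction becomes a future lag input
    return out
```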
+ ) + return None + + try: + regressor = _resolve_sklearn_model(self._sklearn_model) + forecaster = ForecasterRecursive(estimator=regressor, lags=self._num_lags) + + y = pd.Series(powers, name="power") + forecaster.fit(y=y) + + predictions = forecaster.predict(steps=steps) + return self._build_forecast(predictions.tolist()) + except Exception as exc: + if self._logger: + self._logger.warning(f"Skforecast on-the-fly fit failed: {exc}") + return None + + @staticmethod + def _build_forecast(predictions: List[float]) -> LoadEnergyConsumption: + """Convert a list of predicted power values to LoadEnergyConsumption.""" + now = Timestamp(datetime.now(timezone.utc)) + intervals: List[HomeLoadEnergyInterval] = [] + for i, power_val in enumerate(predictions): + start = Timestamp(now + timedelta(hours=i)) + end = Timestamp(start + timedelta(hours=1)) + power = Watts(max(0.0, float(power_val))) + point = HomeLoadPowerPoint(timestamp=start, power=power) + intervals.append( + HomeLoadEnergyInterval( + start=start, + end=end, + power_points=[point], + energy=WattHours(float(power)), + ) + ) + return LoadEnergyConsumption(timestamp=now, intervals=intervals) + + @staticmethod + def tune( + y_series: "pd.Series", + sklearn_model_name: str = "RandomForestRegressor", + num_lags: int = 72, + steps: int = 24, + n_trials: int = 20, + metric: str = "mean_absolute_error", + ) -> tuple: + """Run Bayesian hyperparameter optimisation via Optuna. + + Returns ``(best_params, tuned_forecaster)`` where *best_params* is a + dict of the winning hyperparameter combination and *tuned_forecaster* + is the ``ForecasterRecursive`` already refit with those params. + + This is a **static helper** so it can be called from the training + service without instantiating a full provider. 
+ """ + import optuna + from skforecast.model_selection import TimeSeriesFold, bayesian_search_forecaster + + optuna.logging.set_verbosity(optuna.logging.WARNING) + + regressor = _resolve_sklearn_model(sklearn_model_name) + forecaster = ForecasterRecursive(estimator=regressor, lags=num_lags) + + cv = TimeSeriesFold( + steps=steps, + initial_train_size=len(y_series) - steps * 2, + refit=False, + fixed_train_size=False, + ) + + search_space = _build_search_space(sklearn_model_name) + + results_df, _study = bayesian_search_forecaster( + forecaster=forecaster, + y=y_series, + cv=cv, + search_space=search_space, + metric=metric, + n_trials=n_trials, + return_best=True, + verbose=False, + show_progress=False, + ) + + best_params = results_df.iloc[0].to_dict() if not results_df.empty else {} + # Keep only hyperparameter keys (filter out metric columns) + param_keys = {k for k in best_params if k not in ("mean_absolute_error", "mean_squared_error", metric)} + best_params = {k: v for k, v in best_params.items() if k in param_keys} + + # return_best=True refits the forecaster in-place with the best params + return best_params, forecaster + + @staticmethod + def backtest( + forecaster: "ForecasterRecursive", + y_series: "pd.Series", + steps: int = 24, + folds: int = 3, + metric: str = "mean_absolute_error", + ) -> dict: + """Run rolling-window backtesting on an already-fit forecaster. + + Returns a dict with ``backtest_mae``, ``backtest_rmse`` and + ``backtest_folds``. 
+ """ + import numpy as np + from skforecast.model_selection import TimeSeriesFold, backtesting_forecaster + + # Need at least window_size + steps*(folds+1) data points + window = getattr(forecaster, "window_size", steps) + min_required = window + steps * (folds + 1) + if len(y_series) < min_required: + return {"backtest_mae": None, "backtest_rmse": None, "backtest_folds": 0} + + initial_train_size = len(y_series) - steps * folds + if initial_train_size <= window: + return {"backtest_mae": None, "backtest_rmse": None, "backtest_folds": 0} + + cv = TimeSeriesFold( + steps=steps, + initial_train_size=initial_train_size, + refit=False, + fixed_train_size=False, + ) + + metric_values, predictions = backtesting_forecaster( + forecaster=forecaster, + y=y_series, + cv=cv, + metric=[metric, "mean_squared_error"], + verbose=False, + show_progress=False, + ) + + # metric_values is a DataFrame with one row, columns = metric names + bt_mae = float(metric_values[metric].iloc[0]) if metric in metric_values.columns else None + bt_mse = ( + float(metric_values["mean_squared_error"].iloc[0]) + if "mean_squared_error" in metric_values.columns + else None + ) + bt_rmse = float(np.sqrt(bt_mse)) if bt_mse is not None else None + + # Number of folds = number of complete prediction windows + actual_folds = len(predictions) // steps if len(predictions) >= steps else 0 + + return { + "backtest_mae": bt_mae, + "backtest_rmse": bt_rmse, + "backtest_folds": actual_folds, + } + + +def _build_search_space(sklearn_model_name: str): + """Return an Optuna search_space callable for the given model.""" + import optuna + + def _rf_space(trial: optuna.Trial) -> dict: + return { + "n_estimators": trial.suggest_int("n_estimators", 50, 400), + "max_depth": trial.suggest_int("max_depth", 3, 20), + "min_samples_leaf": trial.suggest_int("min_samples_leaf", 1, 10), + "lags": trial.suggest_categorical("lags", [24, 48, 72]), + } + + def _gb_space(trial: optuna.Trial) -> dict: + return { + "n_estimators": 
trial.suggest_int("n_estimators", 50, 400), + "max_depth": trial.suggest_int("max_depth", 3, 15), + "learning_rate": trial.suggest_float("learning_rate", 0.01, 0.3, log=True), + "lags": trial.suggest_categorical("lags", [24, 48, 72]), + } + + def _ridge_space(trial: optuna.Trial) -> dict: + return { + "alpha": trial.suggest_float("alpha", 0.01, 100.0, log=True), + "lags": trial.suggest_categorical("lags", [24, 48, 72]), + } + + def _knn_space(trial: optuna.Trial) -> dict: + return { + "n_neighbors": trial.suggest_int("n_neighbors", 3, 30), + "weights": trial.suggest_categorical("weights", ["uniform", "distance"]), + "lags": trial.suggest_categorical("lags", [24, 48, 72]), + } + + def _default_space(trial: optuna.Trial) -> dict: + return { + "lags": trial.suggest_categorical("lags", [24, 48, 72]), + } + + space_map = { + "RandomForestRegressor": _rf_space, + "ExtraTreesRegressor": _rf_space, + "GradientBoostingRegressor": _gb_space, + "AdaBoostRegressor": _gb_space, + "Ridge": _ridge_space, + "Lasso": _ridge_space, + "ElasticNet": _ridge_space, + "KNeighborsRegressor": _knn_space, + } + return space_map.get(sklearn_model_name, _default_space) diff --git a/edge_mining/adapters/domain/home_load/forecast_providers/statsmodels_hw.py b/edge_mining/adapters/domain/home_load/forecast_providers/statsmodels_hw.py new file mode 100644 index 0000000..d8d1b87 --- /dev/null +++ b/edge_mining/adapters/domain/home_load/forecast_providers/statsmodels_hw.py @@ -0,0 +1,194 @@ +"""Statsmodels (Holt-Winters / SARIMA) forecast provider for energy load consumption.""" + +import pickle +from datetime import datetime, timedelta, timezone +from typing import List, Optional + +from edge_mining.domain.common import EntityId, Timestamp, WattHours, Watts +from edge_mining.domain.home_load.common import EnergyLoadForecastProviderAdapter +from edge_mining.domain.home_load.exceptions import EnergyLoadForecastProviderError +from edge_mining.domain.home_load.ports import 
EnergyLoadForecastProviderPort, LoadConsumptionModelRepository +from edge_mining.domain.home_load.value_objects import ( + HomeLoadEnergyInterval, + HomeLoadPowerPoint, + LoadEnergyConsumption, +) +from edge_mining.shared.adapter_configs.home_load import EnergyLoadForecastProviderStatsmodelsConfig +from edge_mining.shared.external_services.ports import ExternalServicePort +from edge_mining.shared.interfaces.config import Configuration +from edge_mining.shared.interfaces.factories import EnergyLoadForecastAdapterFactory +from edge_mining.shared.logging.port import LoggerPort + +from .features import fill_missing_hours, intervals_to_hourly_series + +# Lazy imports to avoid hard dependency when [ml] extras are not installed. +_HW_AVAILABLE = False +try: + from statsmodels.tsa.holtwinters import ExponentialSmoothing + + _HW_AVAILABLE = True +except ImportError: + ExponentialSmoothing = None # type: ignore[misc,assignment] + + +class StatsmodelsForecastProviderFactory(EnergyLoadForecastAdapterFactory): + """Factory for creating a StatsmodelsForecastProvider instance.""" + + def __init__(self, model_repo: Optional[LoadConsumptionModelRepository] = None) -> None: + self._model_repo = model_repo + + def create( + self, + config: Optional[Configuration], + logger: Optional[LoggerPort], + external_service: Optional[ExternalServicePort], + ) -> "StatsmodelsForecastProvider": + if config is not None and not isinstance(config, EnergyLoadForecastProviderStatsmodelsConfig): + raise EnergyLoadForecastProviderError( + "Invalid configuration type for Statsmodels energy load forecast provider. " + "Expected EnergyLoadForecastProviderStatsmodelsConfig." 
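Holt-Winters builds on exponential smoothing of the level, l_t = α·y_t + (1−α)·l_{t−1}, extended with additive trend and seasonal components. A minimal sketch of the level-only update (statsmodels' `ExponentialSmoothing`, as used by this provider, fits the full trend + seasonal model and optimizes α automatically):

```python
def simple_exponential_smoothing(series: list[float], alpha: float = 0.5) -> list[float]:
    """Level-only exponential smoothing: l_t = alpha * y_t + (1 - alpha) * l_{t-1}."""
    level = series[0]
    smoothed = [level]
    for y in series[1:]:
        level = alpha * y + (1 - alpha) * level
        smoothed.append(level)
    return smoothed
```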
+ ) + + hours_ahead = 3 + weeks_lookback = 8 + seasonal_periods = 24 + if isinstance(config, EnergyLoadForecastProviderStatsmodelsConfig): + hours_ahead = config.hours_ahead + weeks_lookback = config.weeks_lookback + seasonal_periods = config.seasonal_periods + + return StatsmodelsForecastProvider( + hours_ahead=hours_ahead, + weeks_lookback=weeks_lookback, + seasonal_periods=seasonal_periods, + model_repo=self._model_repo, + logger=logger, + ) + + +class StatsmodelsForecastProvider(EnergyLoadForecastProviderPort): + """Forecast provider using Holt-Winters exponential smoothing from statsmodels. + + If a pre-trained model exists in ``model_repo`` it will be used. + Otherwise the provider fits a new model on-the-fly from the supplied history + (slower, but always works as a fallback). + """ + + def __init__( + self, + hours_ahead: int = 3, + weeks_lookback: int = 8, + seasonal_periods: int = 24, + model_repo: Optional[LoadConsumptionModelRepository] = None, + device_id: Optional[EntityId] = None, + logger: Optional[LoggerPort] = None, + ): + super().__init__(forecast_provider_type=EnergyLoadForecastProviderAdapter.STATSMODELS) + self._hours_ahead = hours_ahead + self._weeks_lookback = weeks_lookback + self._seasonal_periods = seasonal_periods + self._model_repo = model_repo + self._device_id = device_id + self._logger = logger + + @property + def min_required_history_hours(self) -> int: # noqa: D102 + return self._seasonal_periods * 2 + + def get_consumption_forecast( + self, consumption_history: LoadEnergyConsumption, hours_ahead: int = 3 + ) -> Optional[LoadEnergyConsumption]: + if not _HW_AVAILABLE: + raise EnergyLoadForecastProviderError( + "statsmodels is not installed. Install the [ml] extras to enable Holt-Winters forecasting." 
+ ) + + effective_hours = self._hours_ahead or hours_ahead + if effective_hours <= 0: + return None + + # Try to load a pre-trained model + predictions = self._predict_from_saved_model(effective_hours) + + # Fallback: fit on-the-fly + if predictions is None: + predictions = self._fit_and_predict(consumption_history, effective_hours) + + if predictions is None: + return None + + return self._build_forecast(predictions) + + # ------------------------------------------------------------------ + # Internal helpers + # ------------------------------------------------------------------ + + def _predict_from_saved_model(self, hours_ahead: int) -> Optional[List[float]]: + """Try to load a saved model from the repository and forecast.""" + if self._model_repo is None: + return None + model_entity = self._model_repo.get_active_model( + adapter_type=EnergyLoadForecastProviderAdapter.STATSMODELS, + device_id=self._device_id, + ) + if model_entity is None or model_entity.model_bytes is None: + return None + try: + fitted = pickle.loads(model_entity.model_bytes) # noqa: S301 + forecast = fitted.forecast(hours_ahead) + return [max(0.0, float(v)) for v in forecast] + except Exception as exc: + if self._logger: + self._logger.warning(f"Failed to use saved statsmodels model: {exc}") + return None + + def _fit_and_predict(self, consumption_history: LoadEnergyConsumption, hours_ahead: int) -> Optional[List[float]]: + """Fit a Holt-Winters model on the provided history and forecast.""" + series = intervals_to_hourly_series(consumption_history) + series = fill_missing_hours(series) + + if len(series) < self._seasonal_periods * 2: + raise EnergyLoadForecastProviderError( + f"Insufficient data for Holt-Winters forecasting: {len(series)} hourly data points " + f"available, but at least {self._seasonal_periods * 2} are required. " + f"Collect more history before requesting a forecast." 
+ ) + + # Limit lookback + max_points = self._weeks_lookback * 7 * 24 + if len(series) > max_points: + series = series[-max_points:] + + powers = [p for _, p in series] + + try: + model = ExponentialSmoothing( + powers, + trend="add", + seasonal="add", + seasonal_periods=self._seasonal_periods, + ) + fitted = model.fit(optimized=True) + forecast = fitted.forecast(hours_ahead) + return [max(0.0, float(v)) for v in forecast] + except Exception as exc: + raise EnergyLoadForecastProviderError(f"Holt-Winters model fitting failed: {exc}") from exc + + def _build_forecast(self, predictions: List[float]) -> LoadEnergyConsumption: + """Convert a list of predicted avg_power values to LoadEnergyConsumption.""" + now = Timestamp(datetime.now(timezone.utc)) + intervals: List[HomeLoadEnergyInterval] = [] + for i, power_val in enumerate(predictions): + start = Timestamp(now + timedelta(hours=i)) + end = Timestamp(start + timedelta(hours=1)) + power = Watts(power_val) + point = HomeLoadPowerPoint(timestamp=start, power=power) + intervals.append( + HomeLoadEnergyInterval( + start=start, + end=end, + power_points=[point], + energy=WattHours(power_val), + ) + ) + return LoadEnergyConsumption(timestamp=now, intervals=intervals) diff --git a/edge_mining/adapters/domain/home_load/forecast_providers/typical_profile.py b/edge_mining/adapters/domain/home_load/forecast_providers/typical_profile.py new file mode 100644 index 0000000..023a33c --- /dev/null +++ b/edge_mining/adapters/domain/home_load/forecast_providers/typical_profile.py @@ -0,0 +1,154 @@ +"""TypicalProfile forecast provider for energy load consumption. + +Forecasts by computing the "typical" consumption profile: historical data is +grouped by **(month, day_of_week, hour_of_day)** and averaged. This captures +both weekly patterns (workday vs. weekend) *and* seasonal variation (summer vs. +winter) — more granular than ``SeasonalBaseline`` which only uses (dow, hour). 
+""" + +from collections import defaultdict +from datetime import datetime, timedelta, timezone +from typing import Dict, List, Optional, Tuple + +from edge_mining.domain.common import Timestamp, WattHours, Watts +from edge_mining.domain.home_load.common import EnergyLoadForecastProviderAdapter +from edge_mining.domain.home_load.exceptions import EnergyLoadForecastProviderError +from edge_mining.domain.home_load.ports import EnergyLoadForecastProviderPort +from edge_mining.domain.home_load.value_objects import ( + HomeLoadEnergyInterval, + HomeLoadPowerPoint, + LoadEnergyConsumption, +) +from edge_mining.shared.adapter_configs.home_load import EnergyLoadForecastProviderTypicalProfileConfig +from edge_mining.shared.external_services.ports import ExternalServicePort +from edge_mining.shared.interfaces.config import Configuration +from edge_mining.shared.interfaces.factories import EnergyLoadForecastAdapterFactory +from edge_mining.shared.logging.port import LoggerPort + + +class TypicalProfileForecastProviderFactory(EnergyLoadForecastAdapterFactory): + """Factory for creating a TypicalProfileForecastProvider instance.""" + + def create( + self, + config: Optional[Configuration], + logger: Optional[LoggerPort], + external_service: Optional[ExternalServicePort], + ) -> "TypicalProfileForecastProvider": + if config is not None and not isinstance(config, EnergyLoadForecastProviderTypicalProfileConfig): + raise EnergyLoadForecastProviderError( + "Invalid configuration type for TypicalProfile energy load forecast provider. " + "Expected EnergyLoadForecastProviderTypicalProfileConfig." 
+            )
+
+        hours_ahead = 24
+        weeks_lookback = 8
+        if isinstance(config, EnergyLoadForecastProviderTypicalProfileConfig):
+            hours_ahead = config.hours_ahead
+            weeks_lookback = config.weeks_lookback
+
+        return TypicalProfileForecastProvider(
+            hours_ahead=hours_ahead,
+            weeks_lookback=weeks_lookback,
+            logger=logger,
+        )
+
+
+class TypicalProfileForecastProvider(EnergyLoadForecastProviderPort):
+    """Forecast by averaging historical power for each (month, dow, hour) slot.
+
+    Compared to ``SeasonalBaseline`` (which groups by ``(dow, hour)`` only),
+    this provider also factors in the **month**, so the profile naturally adapts
+    to seasonal consumption changes (heating in winter, AC in summer, etc.).
+
+    If insufficient data exists for the exact (month, dow, hour) triplet, the
+    provider falls back to (dow, hour) and then to the global average.
+    """
+
+    def __init__(
+        self,
+        hours_ahead: int = 24,
+        weeks_lookback: int = 8,
+        logger: Optional[LoggerPort] = None,
+    ):
+        super().__init__(forecast_provider_type=EnergyLoadForecastProviderAdapter.TYPICAL_PROFILE)
+        self._hours_ahead = hours_ahead
+        self._weeks_lookback = weeks_lookback
+        self._logger = logger
+
+    @property
+    def min_required_history_hours(self) -> int:  # noqa: D102
+        return self._weeks_lookback * 168  # weeks × 168 h/week
+
+    def get_consumption_forecast(
+        self, consumption_history: LoadEnergyConsumption, hours_ahead: int = 24
+    ) -> Optional[LoadEnergyConsumption]:
+        effective_hours = self._hours_ahead or hours_ahead
+        if effective_hours <= 0:
+            return None
+
+        if not consumption_history.intervals:
+            return None
+
+        # Build profiles at two granularity levels
+        # Level 1 (precise): (month, dow, hour) → list[power]
+        profile_mdh: Dict[Tuple[int, int, int], List[float]] = defaultdict(list)
+        # Level 2 (fallback): (dow, hour) → list[power]
+        profile_dh: Dict[Tuple[int, int], List[float]] = defaultdict(list)
+
+        for interval in consumption_history.intervals:
+            month = interval.start.month
+            dow = interval.start.weekday()
+            hod = interval.start.hour
+            power = float(interval.avg_power)
+            if power >= 0:
+                profile_mdh[(month, dow, hod)].append(power)
+                profile_dh[(dow, hod)].append(power)
+
+        if not profile_dh:
+            return None
+
+        # Global fallback
+        all_powers = [p for powers in profile_dh.values() for p in powers]
+        global_avg = sum(all_powers) / len(all_powers) if all_powers else 0.0
+
+        now = Timestamp(datetime.now(timezone.utc))
+        intervals: List[HomeLoadEnergyInterval] = []
+        for i in range(effective_hours):
+            start = Timestamp(now + timedelta(hours=i))
+            end = Timestamp(start + timedelta(hours=1))
+
+            month = start.month
+            dow = start.weekday()
+            hod = start.hour
+
+            # Try precise (month, dow, hour) first, then (dow, hour), then global
+            mdh_values = profile_mdh.get((month, dow, hod))
+            if mdh_values:
+                slot_power = Watts(sum(mdh_values) / len(mdh_values))
+            else:
+                dh_values = profile_dh.get((dow, hod))
+                if dh_values:
+                    slot_power = Watts(sum(dh_values) / len(dh_values))
+                else:
+                    slot_power = Watts(global_avg)
+
+            point = HomeLoadPowerPoint(timestamp=start, power=slot_power)
+            intervals.append(
+                HomeLoadEnergyInterval(
+                    start=start,
+                    end=end,
+                    power_points=[point],
+                    energy=WattHours(float(slot_power)),
+                )
+            )
+
+        forecast = LoadEnergyConsumption(timestamp=now, intervals=intervals)
+
+        if self._logger:
+            self._logger.debug(
+                f"TypicalProfileForecastProvider: {len(profile_mdh)} (m,d,h) slots, "
+                f"{len(profile_dh)} (d,h) slots, "
+                f"{effective_hours}h ahead, total_energy={forecast.total_energy:.0f}Wh"
+            )
+        return forecast
diff --git a/edge_mining/adapters/domain/home_load/forecast_providers/xgboost_provider.py b/edge_mining/adapters/domain/home_load/forecast_providers/xgboost_provider.py
new file mode 100644
index 0000000..ccf1be0
--- /dev/null
+++ b/edge_mining/adapters/domain/home_load/forecast_providers/xgboost_provider.py
@@ -0,0 +1,233 @@
+"""XGBoost forecast provider for energy load consumption."""
+
+import pickle
+from datetime import datetime, timedelta, timezone
+from typing import List, Optional
+
+from edge_mining.domain.common import EntityId, Timestamp, WattHours, Watts
+from edge_mining.domain.home_load.common import EnergyLoadForecastProviderAdapter
+from edge_mining.domain.home_load.exceptions import EnergyLoadForecastProviderError
+from edge_mining.domain.home_load.ports import EnergyLoadForecastProviderPort, LoadConsumptionModelRepository
+from edge_mining.domain.home_load.value_objects import (
+    HomeLoadEnergyInterval,
+    HomeLoadPowerPoint,
+    LoadEnergyConsumption,
+)
+from edge_mining.shared.adapter_configs.home_load import EnergyLoadForecastProviderXGBoostConfig
+from edge_mining.shared.external_services.ports import ExternalServicePort
+from edge_mining.shared.interfaces.config import Configuration
+from edge_mining.shared.interfaces.factories import EnergyLoadForecastAdapterFactory
+from edge_mining.shared.logging.port import LoggerPort
+
+from .features import (
+    build_calendar_features,
+    fill_missing_hours,
+    intervals_to_hourly_series,
+    prepare_supervised_dataset,
+)
+
+# Lazy imports
+_XGB_AVAILABLE = False
+try:
+    import xgboost as xgb
+
+    _XGB_AVAILABLE = True
+except ImportError:
+    xgb = None  # type: ignore[assignment]
+
+
+class XGBoostForecastProviderFactory(EnergyLoadForecastAdapterFactory):
+    """Factory for creating an XGBoostForecastProvider instance."""
+
+    def __init__(self, model_repo: Optional[LoadConsumptionModelRepository] = None) -> None:
+        self._model_repo = model_repo
+
+    def create(
+        self,
+        config: Optional[Configuration],
+        logger: Optional[LoggerPort],
+        external_service: Optional[ExternalServicePort],
+    ) -> "XGBoostForecastProvider":
+        if config is not None and not isinstance(config, EnergyLoadForecastProviderXGBoostConfig):
+            raise EnergyLoadForecastProviderError(
+                "Invalid configuration type for XGBoost energy load forecast provider. "
+                "Expected EnergyLoadForecastProviderXGBoostConfig."
+            )
+
+        hours_ahead = 3
+        weeks_lookback = 8
+        n_estimators = 100
+        max_depth = 6
+        learning_rate = 0.1
+        if isinstance(config, EnergyLoadForecastProviderXGBoostConfig):
+            hours_ahead = config.hours_ahead
+            weeks_lookback = config.weeks_lookback
+            n_estimators = config.n_estimators
+            max_depth = config.max_depth
+            learning_rate = config.learning_rate
+
+        return XGBoostForecastProvider(
+            hours_ahead=hours_ahead,
+            weeks_lookback=weeks_lookback,
+            n_estimators=n_estimators,
+            max_depth=max_depth,
+            learning_rate=learning_rate,
+            model_repo=self._model_repo,
+            logger=logger,
+        )
+
+
+class XGBoostForecastProvider(EnergyLoadForecastProviderPort):
+    """Forecast provider using XGBoost gradient boosting.
+
+    Uses calendar features (hour, day-of-week, is-weekend, month) and lag
+    features (1h, 2h, 3h, 24h, 168h) to predict avg_power per future hour.
+
+    If a pre-trained model exists in ``model_repo`` it is used.
+    Otherwise, fits on-the-fly from the supplied history.
+    """
+
+    def __init__(
+        self,
+        hours_ahead: int = 3,
+        weeks_lookback: int = 8,
+        n_estimators: int = 100,
+        max_depth: int = 6,
+        learning_rate: float = 0.1,
+        model_repo: Optional[LoadConsumptionModelRepository] = None,
+        device_id: Optional[EntityId] = None,
+        logger: Optional[LoggerPort] = None,
+    ):
+        super().__init__(forecast_provider_type=EnergyLoadForecastProviderAdapter.XGBOOST)
+        self._hours_ahead = hours_ahead
+        self._weeks_lookback = weeks_lookback
+        self._n_estimators = n_estimators
+        self._max_depth = max_depth
+        self._learning_rate = learning_rate
+        self._model_repo = model_repo
+        self._device_id = device_id
+        self._logger = logger
+
+    @property
+    def min_required_history_hours(self) -> int:  # noqa: D102
+        # max lag (168h) + minimum training samples (48) + forecast horizon
+        return 168 + 48 + self._hours_ahead
+
+    def get_consumption_forecast(
+        self, consumption_history: LoadEnergyConsumption, hours_ahead: int = 3
+    ) -> Optional[LoadEnergyConsumption]:
+        if not _XGB_AVAILABLE:
+            if self._logger:
+                self._logger.warning("xgboost is not installed — cannot produce forecast")
+            return None
+
+        effective_hours = self._hours_ahead or hours_ahead
+        if effective_hours <= 0:
+            return None
+
+        # Try saved model first
+        predictions = self._predict_from_saved_model(consumption_history, effective_hours)
+
+        # Fallback: fit on-the-fly
+        if predictions is None:
+            predictions = self._fit_and_predict(consumption_history, effective_hours)
+
+        if predictions is None:
+            return None
+
+        return self._build_forecast(predictions)
+
+    # ------------------------------------------------------------------
+    # Internal
+    # ------------------------------------------------------------------
+
+    def _predict_from_saved_model(
+        self, consumption_history: LoadEnergyConsumption, hours_ahead: int
+    ) -> Optional[List[float]]:
+        if self._model_repo is None:
+            return None
+        model_entity = self._model_repo.get_active_model(
+            adapter_type=EnergyLoadForecastProviderAdapter.XGBOOST,
+            device_id=self._device_id,
+        )
+        if model_entity is None or model_entity.model_bytes is None:
+            return None
+        try:
+            saved_model = pickle.loads(model_entity.model_bytes)  # noqa: S301
+            return self._predict_future(saved_model, consumption_history, hours_ahead)
+        except Exception as exc:
+            if self._logger:
+                self._logger.warning(f"Failed to use saved XGBoost model: {exc}")
+            return None
+
+    def _fit_and_predict(self, consumption_history: LoadEnergyConsumption, hours_ahead: int) -> Optional[List[float]]:
+        X, y = prepare_supervised_dataset(consumption_history, hours_ahead=hours_ahead)
+        if len(X) < 48:
+            if self._logger:
+                self._logger.debug(f"Insufficient training data for XGBoost ({len(X)} samples, need 48)")
+            return None
+
+        try:
+            model = xgb.XGBRegressor(
+                n_estimators=self._n_estimators,
+                max_depth=self._max_depth,
+                learning_rate=self._learning_rate,
+                objective="reg:squarederror",
+                verbosity=0,
+            )
+            model.fit(X, y)
+            return self._predict_future(model, consumption_history, hours_ahead)
+        except Exception as exc:
+            if self._logger:
+                self._logger.warning(f"XGBoost fit failed: {exc}")
+            return None
+
+    def _predict_future(
+        self, model: "xgb.XGBRegressor", consumption_history: LoadEnergyConsumption, hours_ahead: int
+    ) -> Optional[List[float]]:
+        """Build feature rows for the next N hours and predict."""
+        series = intervals_to_hourly_series(consumption_history)
+        series = fill_missing_hours(series)
+        if not series:
+            return None
+
+        powers = [p for _, p in series]
+        now = datetime.now(timezone.utc)
+        lags = [1, 2, 3, 24, 168]
+
+        predictions: List[float] = []
+        # Iteratively predict one step at a time, appending predictions to powers
+        extended_powers = list(powers)
+        for step in range(hours_ahead):
+            future_ts = now + timedelta(hours=step)
+            cal = build_calendar_features([future_ts])[0]
+            lag_row = []
+            n = len(extended_powers)
+            for lag in lags:
+                idx = n - lag
+                lag_row.append(extended_powers[idx] if idx >= 0 else 0.0)
+            feature_row = [cal + lag_row]
+            pred = float(model.predict(feature_row)[0])
+            pred = max(0.0, pred)
+            predictions.append(pred)
+            extended_powers.append(pred)
+
+        return predictions
+
+    def _build_forecast(self, predictions: List[float]) -> LoadEnergyConsumption:
+        now = Timestamp(datetime.now(timezone.utc))
+        intervals: List[HomeLoadEnergyInterval] = []
+        for i, power_val in enumerate(predictions):
+            start = Timestamp(now + timedelta(hours=i))
+            end = Timestamp(start + timedelta(hours=1))
+            power = Watts(power_val)
+            point = HomeLoadPowerPoint(timestamp=start, power=power)
+            intervals.append(
+                HomeLoadEnergyInterval(
+                    start=start,
+                    end=end,
+                    power_points=[point],
+                    energy=WattHours(power_val),
+                )
+            )
+        return LoadEnergyConsumption(timestamp=now, intervals=intervals)
diff --git a/edge_mining/adapters/domain/home_load/history_providers/__init__.py b/edge_mining/adapters/domain/home_load/history_providers/__init__.py
new file mode 100644
index 0000000..d501d6a
--- /dev/null
+++ b/edge_mining/adapters/domain/home_load/history_providers/__init__.py
@@ -0,0 +1 @@
+"""Collection of home load history provider adapters."""
diff --git a/edge_mining/adapters/domain/home_load/history_providers/dummy.py b/edge_mining/adapters/domain/home_load/history_providers/dummy.py
new file mode 100644
index 0000000..198aa1c
--- /dev/null
+++ b/edge_mining/adapters/domain/home_load/history_providers/dummy.py
@@ -0,0 +1,41 @@
+"""Dummy adapter that serves cached power points from the history repository."""
+
+from typing import List, Optional
+
+from edge_mining.adapters.domain.home_load.history_providers.helpers import group_power_points_into_intervals
+from edge_mining.domain.common import EntityId, Timestamp
+from edge_mining.domain.home_load.common import EnergyLoadHistoryProviderAdapter
+from edge_mining.domain.home_load.ports import EnergyLoadHistoryProviderPort, EnergyLoadHistoryRepository
+from edge_mining.domain.home_load.value_objects import HomeLoadEnergyInterval, HomeLoadPowerPoint
+from edge_mining.shared.logging.port import LoggerPort
+
+
+class DummyEnergyLoadHistoryProvider(EnergyLoadHistoryProviderPort):
+    """Dummy history provider that reads directly from the history repository.
+
+    No external fetching — it just serves whatever has already been ingested
+    into the repo for the bound device. Useful for testing and as a fallback.
+ """ + + def __init__( + self, + device_id: EntityId, + history_repo: EnergyLoadHistoryRepository, + logger: Optional[LoggerPort] = None, + ): + super().__init__(device_id=device_id, provider_type=EnergyLoadHistoryProviderAdapter.DUMMY) + self._history_repo = history_repo + self._logger = logger + + async def get_power_points(self, start: Timestamp, end: Timestamp) -> List[HomeLoadPowerPoint]: + """Return cached power points for this device in [start, end).""" + if self._logger: + self._logger.debug(f"DummyEnergyLoadHistoryProvider: get_power_points({self.device_id}, [{start}, {end}))") + return self._history_repo.get_power_points(self.device_id, start, end) + + async def get_history(self, start: Timestamp, end: Timestamp) -> List[HomeLoadEnergyInterval]: + """Return 1-hour consumption intervals for this device in [start, end).""" + if self._logger: + self._logger.debug(f"DummyEnergyLoadHistoryProvider: get_history({self.device_id}, [{start}, {end}))") + power_points = await self.get_power_points(start, end) + return group_power_points_into_intervals(power_points, start=start, end=end) diff --git a/edge_mining/adapters/domain/home_load/history_providers/helpers.py b/edge_mining/adapters/domain/home_load/history_providers/helpers.py new file mode 100644 index 0000000..27fd983 --- /dev/null +++ b/edge_mining/adapters/domain/home_load/history_providers/helpers.py @@ -0,0 +1,60 @@ +"""Shared helpers for home load history provider adapters.""" + +from datetime import timedelta +from typing import List, Optional + +from edge_mining.domain.common import Timestamp, WattHours +from edge_mining.domain.home_load.value_objects import HomeLoadEnergyInterval, HomeLoadPowerPoint + + +def group_power_points_into_intervals( + power_points: List[HomeLoadPowerPoint], + start: Optional[Timestamp] = None, + end: Optional[Timestamp] = None, +) -> List[HomeLoadEnergyInterval]: + """Group power points into contiguous 1-hour intervals. 
+
+    Intervals walk forward from ``start`` (or first point) by 1-hour steps
+    up to ``end`` (or last point). Empty intervals contribute zero energy
+    so downstream consumers see a contiguous timeline.
+    """
+    if not power_points and (start is None or end is None):
+        return []
+
+    sorted_points = sorted(power_points, key=lambda p: p.timestamp)
+
+    if start is None:
+        start = sorted_points[0].timestamp
+    if end is None:
+        end = sorted_points[-1].timestamp
+    if start >= end:
+        raise ValueError("Start timestamp must be before end timestamp.")
+
+    intervals: List[HomeLoadEnergyInterval] = []
+    current_start = start
+    while current_start < end:
+        current_end = min(current_start + timedelta(hours=1), end)
+
+        interval_points = [p for p in sorted_points if current_start <= p.timestamp < current_end]
+
+        if interval_points:
+            intervals.append(
+                HomeLoadEnergyInterval.create_from_power_points(
+                    start=current_start,
+                    end=current_end,
+                    power_points=interval_points,
+                )
+            )
+        else:
+            intervals.append(
+                HomeLoadEnergyInterval(
+                    start=current_start,
+                    end=current_end,
+                    energy=WattHours(0.0),
+                    power_points=[],
+                )
+            )
+
+        current_start = current_end
+
+    return intervals
diff --git a/edge_mining/adapters/domain/home_load/history_providers/home_assistant_api_history.py b/edge_mining/adapters/domain/home_load/history_providers/home_assistant_api_history.py
new file mode 100644
index 0000000..006c720
--- /dev/null
+++ b/edge_mining/adapters/domain/home_load/history_providers/home_assistant_api_history.py
@@ -0,0 +1,191 @@
+"""
+Home Assistant API Energy Load History adapter (Implementation of Port)
+for the energy home loads domain of Edge Mining Application.
+
+The adapter is device-scoped: each instance is bound at construction time to a
+single ``LoadDevice`` via its ``device_id``. History is fetched from Home
+Assistant and opportunistically cached into the ``EnergyLoadHistoryRepository``
+for that device.
+""" + +from datetime import datetime, timedelta, timezone +from typing import List, Optional, cast + +from edge_mining.adapters.domain.home_load.history_providers.helpers import group_power_points_into_intervals +from edge_mining.adapters.infrastructure.homeassistant.homeassistant_api import ( + ServiceHomeAssistantAPI, +) +from edge_mining.adapters.infrastructure.homeassistant.models import EntityHistory +from edge_mining.adapters.infrastructure.homeassistant.utils import EntityState +from edge_mining.domain.common import EntityId, Timestamp, Watts +from edge_mining.domain.home_load.common import EnergyLoadHistoryProviderAdapter +from edge_mining.domain.home_load.entities import LoadDevice +from edge_mining.domain.home_load.exceptions import ( + EnergyLoadHistoryProviderConfigurationError, + EnergyLoadHistoryProviderError, +) +from edge_mining.domain.home_load.ports import EnergyLoadHistoryProviderPort, EnergyLoadHistoryRepository +from edge_mining.domain.home_load.value_objects import HomeLoadEnergyInterval, HomeLoadPowerPoint +from edge_mining.shared.adapter_configs.home_load import ( + EnergyLoadHistoryProviderHomeAssistantAPIConfig, +) +from edge_mining.shared.external_services.common import ExternalServiceAdapter +from edge_mining.shared.external_services.ports import ExternalServicePort +from edge_mining.shared.interfaces.config import Configuration +from edge_mining.shared.interfaces.factories import EnergyLoadHistoryAdapterFactory +from edge_mining.shared.logging.port import LoggerPort + + +class HomeAssistantAPIEnergyLoadHistoryProviderFactory(EnergyLoadHistoryAdapterFactory): + """Factory for ``HomeAssistantAPIEnergyLoadHistoryProvider`` instances. + + The infrastructure repository is injected at factory construction time + (one repo serves all devices). ``from_load_device`` binds the device-scope + before ``create`` is called. 
+ """ + + def __init__(self, history_repo: EnergyLoadHistoryRepository): + self._history_repo = history_repo + self._load_device: Optional[LoadDevice] = None + + def from_load_device(self, load_device: LoadDevice) -> None: + """Bind the factory to the LoadDevice this adapter will serve.""" + self._load_device = load_device + + def create( + self, + config: Optional[Configuration], + logger: Optional[LoggerPort], + external_service: Optional[ExternalServicePort], + ) -> "HomeAssistantAPIEnergyLoadHistoryProvider": + """Build a device-scoped Home Assistant API history adapter.""" + if self._load_device is None: + raise EnergyLoadHistoryProviderConfigurationError( + "from_load_device(...) must be called before create(...)." + ) + + if not external_service: + raise EnergyLoadHistoryProviderError("External service is required for EnergyLoadHistoryProviderAdapter.") + + if external_service.external_service_type != ExternalServiceAdapter.HOME_ASSISTANT_API: + raise EnergyLoadHistoryProviderError("External service must be of type Home Assistant API") + + if not isinstance(config, EnergyLoadHistoryProviderHomeAssistantAPIConfig): + raise EnergyLoadHistoryProviderConfigurationError( + "Invalid configuration type for HomeAssistantAPI energy load history provider. " + "Expected EnergyLoadHistoryProviderHomeAssistantAPIConfig." + ) + + service_home_assistant_api = cast(ServiceHomeAssistantAPI, external_service) + + return HomeAssistantAPIEnergyLoadHistoryProvider( + device_id=self._load_device.id, + entity_power=config.entity_power, + home_assistant=service_home_assistant_api, + history_repo=self._history_repo, + logger=logger, + ) + + +class HomeAssistantAPIEnergyLoadHistoryProvider(EnergyLoadHistoryProviderPort): + """Fetches energy load history for one LoadDevice from a Home Assistant instance. 
+
+    Caches raw power points in the injected ``EnergyLoadHistoryRepository``
+    (infrastructure dependency — not part of the port contract) to avoid
+    re-hitting Home Assistant for already-observed windows.
+    """
+
+    _CACHE_STALENESS = timedelta(minutes=5)
+
+    def __init__(
+        self,
+        device_id: EntityId,
+        entity_power: str,
+        home_assistant: ServiceHomeAssistantAPI,
+        history_repo: EnergyLoadHistoryRepository,
+        logger: Optional[LoggerPort] = None,
+    ):
+        super().__init__(
+            device_id=device_id,
+            provider_type=EnergyLoadHistoryProviderAdapter.HOME_ASSISTANT_API,
+        )
+        self._home_assistant = home_assistant
+        self._history_repo = history_repo
+        self._logger = logger
+
+        if not entity_power or entity_power.strip() == "":
+            raise EnergyLoadHistoryProviderConfigurationError("Power entity must be provided and cannot be empty.")
+        self._entity_power = entity_power
+
+        if self._logger:
+            self._logger.debug(f"HA history adapter bound to device {device_id} (entity='{entity_power}')")
+
+    async def get_power_points(self, start: Timestamp, end: Timestamp) -> List[HomeLoadPowerPoint]:
+        """Return power points for the bound device in [start, end).
+
+        Hits the cache first; fetches missing or stale tail from Home Assistant.
+        """
+        if start >= end:
+            return []
+
+        cached = self._history_repo.get_power_points(self.device_id, start, end)
+
+        latest_cached: Optional[Timestamp] = max((p.timestamp for p in cached), default=None)
+        now_ts = Timestamp(datetime.now(timezone.utc))
+
+        if latest_cached is None:
+            fetched = await self._fetch_from_home_assistant(start, end)
+            if fetched:
+                self._history_repo.add_power_points(self.device_id, fetched)
+            return sorted(fetched, key=lambda p: p.timestamp)
+
+        if now_ts - latest_cached > self._CACHE_STALENESS and latest_cached < end:
+            if self._logger:
+                self._logger.debug(
+                    f"Cache tail stale for device {self.device_id}: "
+                    f"latest={latest_cached}, now={now_ts}. Fetching incremental."
+                )
+            tail = await self._fetch_from_home_assistant(latest_cached, end)
+            if tail:
+                self._history_repo.add_power_points(self.device_id, tail)
+                cached.extend(tail)
+
+        return sorted(cached, key=lambda p: p.timestamp)
+
+    async def get_history(self, start: Timestamp, end: Timestamp) -> List[HomeLoadEnergyInterval]:
+        """Return 1-hour consumption intervals for the bound device in [start, end)."""
+        if self._logger:
+            self._logger.debug(f"Computing 1h intervals for device {self.device_id} in [{start}, {end}).")
+        power_points = await self.get_power_points(start, end)
+        return group_power_points_into_intervals(power_points, start=start, end=end)
+
+    async def _fetch_from_home_assistant(self, start: Timestamp, end: Timestamp) -> List[HomeLoadPowerPoint]:
+        """Fetch raw power points from Home Assistant REST API."""
+        entity_history: Optional[EntityHistory] = await self._home_assistant.get_entity_history(
+            self._entity_power, start, end
+        )
+        if not entity_history:
+            if self._logger:
+                self._logger.error(f"No history data found for entity '{self._entity_power}'")
+            return []
+
+        points: List[HomeLoadPowerPoint] = []
+        for raw in entity_history.history:
+            if raw.value is None or raw.value.lower() in (
+                EntityState.UNAVAILABLE.value,
+                EntityState.UNKNOWN.value,
+            ):
+                if self._logger:
+                    self._logger.error(f"Invalid power data point '{raw.value}'. Skipping.")
+                continue
+
+            unit = raw.unit or "W"
+            parsed = self._home_assistant.parse_power(raw.value, unit, self._entity_power or "N/A")
+            if parsed is None:
+                if self._logger:
+                    self._logger.error(f"Failed to parse power '{raw.value}' for '{self._entity_power}'. Skipping.")
+                continue
+
+            points.append(HomeLoadPowerPoint(timestamp=Timestamp(raw.timestamp), power=Watts(parsed)))
+
+        return points
diff --git a/edge_mining/adapters/domain/home_load/providers/__init__.py b/edge_mining/adapters/domain/home_load/providers/__init__.py
deleted file mode 100644
index 6734ea4..0000000
--- a/edge_mining/adapters/domain/home_load/providers/__init__.py
+++ /dev/null
@@ -1 +0,0 @@
-"""Collection of home load provider adapters."""
diff --git a/edge_mining/adapters/domain/home_load/providers/dummy.py b/edge_mining/adapters/domain/home_load/providers/dummy.py
deleted file mode 100644
index 73fe744..0000000
--- a/edge_mining/adapters/domain/home_load/providers/dummy.py
+++ /dev/null
@@ -1,58 +0,0 @@
-"""
-Dummy adapter (Implementation of Port) that simulates
-the home loads forecast for Edge Mining Application
-"""
-
-import random
-from datetime import datetime, timedelta
-from typing import Dict, Optional
-
-from edge_mining.domain.common import Timestamp, Watts
-from edge_mining.domain.home_load.common import HomeForecastProviderAdapter
-from edge_mining.domain.home_load.ports import HomeForecastProviderPort
-from edge_mining.domain.home_load.value_objects import ConsumptionForecast
-from edge_mining.shared.logging.port import LoggerPort
-
-
-class DummyHomeForecastProvider(HomeForecastProviderPort):
-    """Generates a very basic fake home load forecast."""
-
-    def __init__(self, load_power_max: float = 500.0, logger: Optional[LoggerPort] = None):
-        """Initializes the DummyHomeForecastProvider."""
-        super().__init__(home_forecast_provider_type=HomeForecastProviderAdapter.DUMMY)
-        self.logger = logger
-
-        self.load_power_max = load_power_max
-        # You can set default values or use the ones from settings if needed
-
-    def get_home_consumption_forecast(self, hours_ahead: int = 3) -> Optional[ConsumptionForecast]:
-        """Get the home consumption forecast."""
-        # Super simple: return a random average load expected soon for next hours_ahead hours.
-        if self.logger:
-            self.logger.debug(
-                f"DummyHomeForecastProvider: "
-                f"Generating home load forecast for {hours_ahead} hours ahead "
-                f"with max load {self.load_power_max} kWp"
-            )
-
-        now = datetime.now()
-        predictions: Dict[Timestamp, Watts] = {}
-
-        # Average Watts expected for the next hours
-        # For simplicity, we just generate a random load value
-        # In a real scenario, this would be based on historical data, time of day, etc.
-        # Here we assume a random load between 200W and max load
-        avg_load = Watts(random.uniform(200, self.load_power_max))
-
-        for i in range(hours_ahead):  # Forecast for next hours_ahead hours
-            future_time = now + timedelta(hours=i)
-            predicted_power = avg_load
-            predictions[Timestamp(future_time)] = predicted_power
-
-        home_forecast = ConsumptionForecast(predicted_watts=predictions, generated_at=Timestamp(now))
-
-        if self.logger:
-            self.logger.debug(
-                f"DummyHomeForecastProvider: Estimated avg home load: {avg_load:.0f}W for next {hours_ahead} hours"
-            )
-        return home_forecast
diff --git a/edge_mining/adapters/domain/home_load/repositories.py b/edge_mining/adapters/domain/home_load/repositories.py
index 2e68e6f..25540c6 100644
--- a/edge_mining/adapters/domain/home_load/repositories.py
+++ b/edge_mining/adapters/domain/home_load/repositories.py
@@ -4,82 +4,149 @@
 import json
 import sqlite3
 import uuid
+from datetime import datetime, timezone
 from typing import Any, Dict, List, Optional
 
-from sqlalchemy import select
+from sqlalchemy import delete, func, insert, select
 
-from edge_mining.adapters.domain.home_load.tables import home_forecast_providers_table, home_profiles_table
+from edge_mining.adapters.domain.home_load.tables import (
+    energy_load_forecast_providers_table,
+    energy_load_history_providers_table,
+    home_load_power_points_table,
+    home_profiles_table,
+    load_consumption_models_table,
+)
 from edge_mining.adapters.infrastructure.persistence.sqlalchemy.base import BaseSQLAlchemyRepository
 from edge_mining.adapters.infrastructure.persistence.sqlite import BaseSqliteRepository
-from edge_mining.domain.common import EntityId
+from edge_mining.domain.common import EntityId, Timestamp, Watts
 from edge_mining.domain.exceptions import ConfigurationError
 from edge_mining.domain.home_load.aggregate_roots import HomeLoadsProfile
-from edge_mining.domain.home_load.common import HomeForecastProviderAdapter
-from edge_mining.domain.home_load.entities import HomeForecastProvider, LoadDevice
+from edge_mining.domain.home_load.common import (
+    EnergyLoadForecastProviderAdapter,
+    EnergyLoadHistoryProviderAdapter,
+    LoadDeviceCategory,
+)
+from edge_mining.domain.home_load.entities import (
+    EnergyLoadForecastProvider,
+    EnergyLoadHistoryProvider,
+    LoadConsumptionModel,
+    LoadDevice,
+)
 from edge_mining.domain.home_load.exceptions import (
-    HomeForecastProviderAlreadyExistsError,
-    HomeForecastProviderError,
-    HomeForecastProviderNotFoundError,
-    HomeForecastProviderConfigurationError,
+    EnergyLoadForecastProviderAlreadyExistsError,
+    EnergyLoadForecastProviderConfigurationError,
+    EnergyLoadForecastProviderError,
+    EnergyLoadForecastProviderNotFoundError,
+    EnergyLoadHistoryProviderAlreadyExistsError,
+    EnergyLoadHistoryProviderConfigurationError,
+    EnergyLoadHistoryProviderError,
+    EnergyLoadHistoryProviderNotFoundError,
 )
 from edge_mining.domain.home_load.ports import (
-    HomeForecastProviderRepository,
+    EnergyLoadForecastProviderRepository,
+    EnergyLoadHistoryProviderRepository,
+    EnergyLoadHistoryRepository,
     HomeLoadsProfileRepository,
+    LoadConsumptionModelRepository,
 )
+from edge_mining.domain.home_load.value_objects import HomeLoadPowerPoint
 from edge_mining.shared.adapter_maps.home_load import (
-    HOME_FORECAST_PROVIDER_CONFIG_TYPE_MAP,
+    ENERGY_LOAD_FORECAST_PROVIDER_CONFIG_TYPE_MAP,
+    ENERGY_LOAD_HISTORY_PROVIDER_CONFIG_TYPE_MAP,
 )
-from edge_mining.shared.interfaces.config import HomeForecastProviderConfig
+from edge_mining.shared.interfaces.config import EnergyLoadForecastProviderConfig, EnergyLoadHistoryProviderConfig
+
+
+# --- HomeLoadsProfile Repositories ---
+
+
+def _device_to_dict(device: LoadDevice) -> Dict[str, Any]:
+    return {
+        "id": str(device.id),
+        "name": device.name,
+        "category": device.category.value,
+        "enabled": device.enabled,
+        "energy_load_forecast_provider_id": (
+            str(device.energy_load_forecast_provider_id) if device.energy_load_forecast_provider_id else None
+        ),
+        "energy_load_history_provider_id": (
+            str(device.energy_load_history_provider_id) if device.energy_load_history_provider_id else None
+        ),
+    }
+
 
-# Simple In-Memory implementation for testing and basic use
+def _dict_to_device(data: Dict[str, Any]) -> LoadDevice:
+    forecast_id = data.get("energy_load_forecast_provider_id")
+    history_id = data.get("energy_load_history_provider_id")
+    return LoadDevice(
+        id=EntityId(uuid.UUID(data["id"])),
+        name=data["name"],
+        category=LoadDeviceCategory(data["category"]),
+        enabled=bool(data.get("enabled", True)),
+        energy_load_forecast_provider_id=EntityId(uuid.UUID(forecast_id)) if forecast_id else None,
+        energy_load_history_provider_id=EntityId(uuid.UUID(history_id)) if history_id else None,
+    )
 
 
 class InMemoryHomeLoadsProfileRepository(HomeLoadsProfileRepository):
-    """In-Memory implementation for the Home Loads Profile Repository."""
+    """In-memory implementation for the Home Loads Profile Repository."""
 
-    def __init__(self, initial_profile: Optional[HomeLoadsProfile] = None):
-        self._profile: Optional[HomeLoadsProfile] = copy.deepcopy(initial_profile)
+    def __init__(self, initial_profiles: Optional[List[HomeLoadsProfile]] = None):
+        self._profiles: Dict[EntityId, HomeLoadsProfile] = {}
+        if initial_profiles:
+            for profile in initial_profiles:
+                self._profiles[profile.id] = copy.deepcopy(profile)
 
-    def get_profile(self) -> Optional[HomeLoadsProfile]:
-        return copy.deepcopy(self._profile)
+    def add(self, profile: HomeLoadsProfile) -> None:
+        self._profiles[profile.id] = copy.deepcopy(profile)
 
-    def save_profile(self, profile: HomeLoadsProfile) -> None:
-        self._profile = copy.deepcopy(profile)
+    def get_by_id(self, profile_id: EntityId) -> Optional[HomeLoadsProfile]:
+        profile = self._profiles.get(profile_id)
+        return copy.deepcopy(profile) if profile else None
+
+    def get_all(self) -> List[HomeLoadsProfile]:
+        return [copy.deepcopy(p) for p in self._profiles.values()]
+
+    def update(self, profile: HomeLoadsProfile) -> None:
+        self._profiles[profile.id] = copy.deepcopy(profile)
+
+    def remove(self, profile_id: EntityId) -> None:
+        self._profiles.pop(profile_id, None)
+
+    def get_by_energy_load_forecast_provider_id(self, provider_id: EntityId) -> List[HomeLoadsProfile]:
+        return [
+            copy.deepcopy(profile)
+            for profile in self._profiles.values()
+            if any(device.energy_load_forecast_provider_id == provider_id for device in profile.devices)
+        ]
 
 
 class SqliteHomeLoadsProfileRepository(HomeLoadsProfileRepository):
     """SQLite implementation for the Home Loads Profile Repository."""
 
-    # fixed UUID for the default profile
-    _DEFAULT_PROFILE_UUID = uuid.UUID("00000000-0000-0000-0000-000000000001")
-
     def __init__(self, db: BaseSqliteRepository):
         self._db = db
         self.logger = db.logger
-
         self._create_tables()
 
     def _create_tables(self):
-        """Create the necessary tables for the Home Load domain if they do not exist."""
         self.logger.debug(f"Ensuring SQLite tables exist for Home Loads Profile Repository in {self._db.db_path}...")
         sql_statements = [
             """
             CREATE TABLE IF NOT EXISTS home_profiles (
-                id TEXT PRIMARY KEY, -- e.g., fixed UUID for default profile
+                id TEXT PRIMARY KEY,
                 name TEXT NOT NULL,
-                devices_json TEXT -- JSON Dict[EntityId_str, LoadDevice_dict]
+                devices_json TEXT -- JSON list of LoadDevice dicts
             );
             """
         ]
         conn = self._db.get_connection()
-
         try:
             with conn:
                 cursor = conn.cursor()
                 for statement in sql_statements:
                     cursor.execute(statement)
-
             self.logger.debug("Home Loads Profile tables checked/created successfully.")
         except sqlite3.Error as e:
             self.logger.error(f"Error creating SQLite tables: {e}")
@@ -88,104 +155,140 @@ def _create_tables(self):
             if conn:
                 conn.close()
 
-    def _device_to_dict(self, device: LoadDevice) -> Dict[str, Any]:
-        return {"id": str(device.id), "name": device.name, "type": device.type}
-
-    def _dict_to_device(self, data: Dict[str, Any]) -> LoadDevice:
-        """Convert a dictionary to a LoadDevice."""
-        return LoadDevice(id=EntityId(uuid.UUID(data["id"])), name=data["name"], type=data["type"])
-
     def _row_to_profile(self, row: sqlite3.Row) -> Optional[HomeLoadsProfile]:
-        """Convert a row to a HomeLoadsProfile."""
         if not row:
             return None
         try:
-            devices_data: Dict = json.loads(row["devices_json"] or "{}")
-            devices = {
-                EntityId(uuid.UUID(id_str)): self._dict_to_device(dev_dict) for id_str, dev_dict in devices_data.items()
-            }
-            return HomeLoadsProfile(id=row["id"], name=row["name"], devices=devices)  # UUID
+            devices_data: List = json.loads(row["devices_json"] or "[]")
+            devices = [_dict_to_device(dev) for dev in devices_data if isinstance(dev, dict)]
+            return HomeLoadsProfile(id=EntityId(uuid.UUID(row["id"])), name=row["name"], devices=devices)
         except (json.JSONDecodeError, ValueError, KeyError, TypeError) as e:
-            self.logger.error(f"Error deserializing HomeLoadsProfile from DB line: {dict(row)}. Error: {e}")
+            self.logger.error(f"Error deserializing HomeLoadsProfile from DB row: {dict(row)}. 
Error: {e}") return None - def get_profile(self) -> Optional[HomeLoadsProfile]: - """Get the home load profile from SQLite.""" - self.logger.debug("Getting home load profile from SQLite.") + def add(self, profile: HomeLoadsProfile) -> None: + self.logger.debug(f"Adding home loads profile '{profile.name}' ({profile.id}) to SQLite.") + sql = "INSERT INTO home_profiles (id, name, devices_json) VALUES (?, ?, ?)" + conn = self._db.get_connection() + try: + devices_json = json.dumps([_device_to_dict(dev) for dev in profile.devices]) + with conn: + conn.execute(sql, (str(profile.id), profile.name, devices_json)) + except sqlite3.Error as e: + self.logger.error(f"SQLite error adding profile {profile.id}: {e}") + raise ConfigurationError(f"DB error adding profile: {e}") from e + finally: + if conn: + conn.close() + + def get_by_id(self, profile_id: EntityId) -> Optional[HomeLoadsProfile]: sql = "SELECT * FROM home_profiles WHERE id = ?" conn = self._db.get_connection() try: cursor = conn.cursor() - cursor.execute(sql, (self._DEFAULT_PROFILE_UUID,)) + cursor.execute(sql, (str(profile_id),)) row = cursor.fetchone() - if row: - return self._row_to_profile(row) - else: - self.logger.info("No home load profile found in DB, returning None.") - return None + return self._row_to_profile(row) if row else None except sqlite3.Error as e: - self.logger.error(f"SQLite error getting home profile: {e}") + self.logger.error(f"SQLite error getting profile {profile_id}: {e}") return None finally: if conn: conn.close() - def save_profile(self, profile: HomeLoadsProfile) -> None: - """Save the home load profile to SQLite.""" - self.logger.debug(f"Saving home load profile '{profile.name}' to SQLite.") - sql = "INSERT OR REPLACE INTO home_profiles (id, name, devices_json) VALUES (?, ?, ?)" + def get_all(self) -> List[HomeLoadsProfile]: + sql = "SELECT * FROM home_profiles" + conn = self._db.get_connection() + try: + cursor = conn.cursor() + cursor.execute(sql) + rows = cursor.fetchall() + 
profiles: List[HomeLoadsProfile] = [] + for row in rows: + profile = self._row_to_profile(row) + if profile: + profiles.append(profile) + return profiles + except sqlite3.Error as e: + self.logger.error(f"SQLite error getting all profiles: {e}") + return [] + finally: + if conn: + conn.close() + + def update(self, profile: HomeLoadsProfile) -> None: + sql = "UPDATE home_profiles SET name = ?, devices_json = ? WHERE id = ?" conn = self._db.get_connection() try: - # Serialize the dictionary of devices - devices_json = json.dumps({str(id): self._device_to_dict(dev) for id, dev in profile.devices.items()}) + devices_json = json.dumps([_device_to_dict(dev) for dev in profile.devices]) with conn: - # Always use the fixed UUID for the default profile - # default - conn.execute( - sql, - (self._DEFAULT_PROFILE_UUID, profile.name, devices_json), - ) + conn.execute(sql, (profile.name, devices_json, str(profile.id))) + except sqlite3.Error as e: + self.logger.error(f"SQLite error updating profile {profile.id}: {e}") + raise ConfigurationError(f"DB error updating profile: {e}") from e + finally: + if conn: + conn.close() + + def remove(self, profile_id: EntityId) -> None: + sql = "DELETE FROM home_profiles WHERE id = ?" 
+ conn = self._db.get_connection() + try: + with conn: + conn.execute(sql, (str(profile_id),)) except sqlite3.Error as e: - self.logger.error(f"SQLite error saving home profile: {e}") - raise ConfigurationError(f"DB error saving home profile: {e}") from e + self.logger.error(f"SQLite error removing profile {profile_id}: {e}") + raise ConfigurationError(f"DB error removing profile: {e}") from e finally: if conn: conn.close() + def get_by_energy_load_forecast_provider_id(self, provider_id: EntityId) -> List[HomeLoadsProfile]: + return [ + profile + for profile in self.get_all() + if any(device.energy_load_forecast_provider_id == provider_id for device in profile.devices) + ] + + +# --- EnergyLoadForecastProvider Repositories --- -class InMemoryHomeForecastProviderRepository(HomeForecastProviderRepository): - """In-memory implementation of HomeForecastProviderRepository for testing purposes.""" + +class InMemoryEnergyLoadForecastProviderRepository(EnergyLoadForecastProviderRepository): + """In-memory implementation of EnergyLoadForecastProviderRepository for testing purposes.""" def __init__(self): - self._home_forecast_providers: List[HomeForecastProvider] = [] + self._energy_load_forecast_providers: List[EnergyLoadForecastProvider] = [] - def add(self, home_forecast_provider: HomeForecastProvider) -> None: - self._home_forecast_providers.append(home_forecast_provider) + def add(self, energy_load_forecast_provider: EnergyLoadForecastProvider) -> None: + self._energy_load_forecast_providers.append(energy_load_forecast_provider) - def get_by_id(self, home_forecast_provider_id: EntityId) -> Optional[HomeForecastProvider]: - for home_forecast_provider in self._home_forecast_providers: - if home_forecast_provider.id == home_forecast_provider_id: - return home_forecast_provider + def get_by_id(self, energy_load_forecast_provider_id: EntityId) -> Optional[EnergyLoadForecastProvider]: + for energy_load_forecast_provider in self._energy_load_forecast_providers: + if 
energy_load_forecast_provider.id == energy_load_forecast_provider_id: + return energy_load_forecast_provider return None - def get_all(self) -> List[HomeForecastProvider]: - return self._home_forecast_providers + def get_all(self) -> List[EnergyLoadForecastProvider]: + return self._energy_load_forecast_providers - def update(self, home_forecast_provider: HomeForecastProvider) -> None: - for i, existing_home_forecast_provider in enumerate(self._home_forecast_providers): - if existing_home_forecast_provider.id == home_forecast_provider.id: - self._home_forecast_providers[i] = home_forecast_provider + def update(self, energy_load_forecast_provider: EnergyLoadForecastProvider) -> None: + for i, existing_provider in enumerate(self._energy_load_forecast_providers): + if existing_provider.id == energy_load_forecast_provider.id: + self._energy_load_forecast_providers[i] = energy_load_forecast_provider return - def remove(self, home_forecast_provider_id: EntityId) -> None: - self._home_forecast_providers = [n for n in self._home_forecast_providers if n.id != home_forecast_provider_id] + def remove(self, energy_load_forecast_provider_id: EntityId) -> None: + self._energy_load_forecast_providers = [ + n for n in self._energy_load_forecast_providers if n.id != energy_load_forecast_provider_id + ] - def get_by_external_service_id(self, external_service_id: EntityId) -> List[HomeForecastProvider]: - """Retrieve all home forecast providers linked to a specific external service.""" + def get_by_external_service_id(self, external_service_id: EntityId) -> List[EnergyLoadForecastProvider]: + """Retrieve all energy load forecast providers linked to a specific external service.""" return ( [ provider - for provider in self._home_forecast_providers + for provider in self._energy_load_forecast_providers if provider.external_service_id == external_service_id ] if external_service_id @@ -193,8 +296,8 @@ def get_by_external_service_id(self, external_service_id: EntityId) -> List[Home ) 
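Worth noting on the two in-memory implementations above: `InMemoryHomeLoadsProfileRepository` deep-copies on both write and read, while `InMemoryEnergyLoadForecastProviderRepository` returns its internal objects directly, so callers can mutate repository state in place. A minimal standalone sketch of the deep-copy style (the `Profile` dataclass here is a hypothetical simplified stand-in, not the project's `HomeLoadsProfile` entity):

```python
import copy
import uuid
from dataclasses import dataclass, field
from typing import Dict, List, Optional


@dataclass
class Profile:
    """Hypothetical, simplified stand-in for HomeLoadsProfile."""
    id: uuid.UUID
    name: str
    devices: List[str] = field(default_factory=list)


class InMemoryProfileRepository:
    """Stores and returns deep copies so callers cannot mutate repository state."""

    def __init__(self) -> None:
        self._profiles: Dict[uuid.UUID, Profile] = {}

    def add(self, profile: Profile) -> None:
        self._profiles[profile.id] = copy.deepcopy(profile)

    def get_by_id(self, profile_id: uuid.UUID) -> Optional[Profile]:
        profile = self._profiles.get(profile_id)
        return copy.deepcopy(profile) if profile else None


repo = InMemoryProfileRepository()
p = Profile(id=uuid.uuid4(), name="default", devices=["heater"])
repo.add(p)
p.devices.append("boiler")           # mutate the caller's object after add()...
stored = repo.get_by_id(p.id)
assert stored is not None
assert stored.devices == ["heater"]  # ...the stored snapshot is unaffected
```

If the forecast-provider repository should offer the same isolation guarantee in tests, the same `copy.deepcopy` pattern applies on its `add`/`get_by_id`/`get_all` paths.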
-class SqliteHomeForecastProviderRepository(HomeForecastProviderRepository): - """SQLite implementation of HomeForecastProviderRepository.""" +class SqliteEnergyLoadForecastProviderRepository(EnergyLoadForecastProviderRepository): + """SQLite implementation of EnergyLoadForecastProviderRepository.""" def __init__(self, db: BaseSqliteRepository): self._db = db @@ -203,19 +306,17 @@ def __init__(self, db: BaseSqliteRepository): self._create_tables() def _create_tables(self): - """Create the necessary table for the Home Forecast Provider if it does not exist.""" self.logger.debug( - f"Ensuring SQLite tables exist for Home Forecast Provider Repository in {self._db.db_path}..." + f"Ensuring SQLite tables exist for Energy Load Forecast Provider Repository in {self._db.db_path}..." ) sql_statements = [ """ - CREATE TABLE IF NOT EXISTS home_forecast_providers ( + CREATE TABLE IF NOT EXISTS energy_load_forecast_providers ( id TEXT PRIMARY KEY, name TEXT NOT NULL, adapter_type TEXT NOT NULL, - config TEXT, -- JSON object of config - external_service_id TEXT -- Optional ID for external service integration - + config TEXT, + external_service_id TEXT ); """ ] @@ -225,8 +326,6 @@ def _create_tables(self): cursor = conn.cursor() for statement in sql_statements: cursor.execute(statement) - - self.logger.debug("Home Forecast providers tables checked/created successfully.") except sqlite3.Error as e: self.logger.error(f"Error creating SQLite tables: {e}") raise ConfigurationError(f"DB error creating tables: {e}") from e @@ -235,210 +334,202 @@ def _create_tables(self): conn.close() def _deserialize_config( - self, adapter_type: HomeForecastProviderAdapter, config_json: str - ) -> HomeForecastProviderConfig: - """Deserialize a JSON string into HomeForecastProviderConfig object.""" + self, adapter_type: EnergyLoadForecastProviderAdapter, config_json: str + ) -> EnergyLoadForecastProviderConfig: data: dict = json.loads(config_json) - if adapter_type not in 
HOME_FORECAST_PROVIDER_CONFIG_TYPE_MAP: - raise HomeForecastProviderNotFoundError( - f"Error reading HomeForecastProvider configuration. Invalid type '{adapter_type}'" + if adapter_type not in ENERGY_LOAD_FORECAST_PROVIDER_CONFIG_TYPE_MAP: + raise EnergyLoadForecastProviderNotFoundError( + f"Error reading EnergyLoadForecastProvider configuration. Invalid type '{adapter_type}'" ) - config_class: Optional[type[HomeForecastProviderConfig]] = HOME_FORECAST_PROVIDER_CONFIG_TYPE_MAP.get( - adapter_type + config_class: Optional[type[EnergyLoadForecastProviderConfig]] = ( + ENERGY_LOAD_FORECAST_PROVIDER_CONFIG_TYPE_MAP.get(adapter_type) ) if not config_class: - raise HomeForecastProviderNotFoundError( - f"Error creating HomeForecastProviderConfig configuration. Type '{adapter_type}'" + raise EnergyLoadForecastProviderNotFoundError( + f"No configuration class mapped for adapter type '{adapter_type}'" ) config_instance = config_class.from_dict(data) - if not isinstance(config_instance, HomeForecastProviderConfig): - raise HomeForecastProviderConfigurationError( - f"Deserialized config is not of type HomeForecastProviderConfig for adapter type {adapter_type}." + if not isinstance(config_instance, EnergyLoadForecastProviderConfig): + raise EnergyLoadForecastProviderConfigurationError( + f"Deserialized config is not of type EnergyLoadForecastProviderConfig " + f"for adapter type {adapter_type}."
) return config_instance - def _row_to_home_forecast_provider(self, row: sqlite3.Row) -> Optional[HomeForecastProvider]: - """Deserialize a row from the database into a HomeForecastProvider object.""" + def _row_to_energy_load_forecast_provider(self, row: sqlite3.Row) -> Optional[EnergyLoadForecastProvider]: if not row: return None try: - home_forecast_provider_type = HomeForecastProviderAdapter(row["adapter_type"]) + provider_type = EnergyLoadForecastProviderAdapter(row["adapter_type"]) + config = self._deserialize_config(provider_type, row["config"]) - # Deserialize the config from the database row - config = self._deserialize_config(home_forecast_provider_type, row["config"]) - - return HomeForecastProvider( + return EnergyLoadForecastProvider( id=EntityId(row["id"]), name=row["name"], - adapter_type=home_forecast_provider_type, + adapter_type=provider_type, config=config, external_service_id=(EntityId(row["external_service_id"]) if row["external_service_id"] else None), ) except (ValueError, KeyError) as e: - self.logger.error(f"Error deserializing HomeForecastProvider from DB row: {row}. Error: {e}") + self.logger.error(f"Error deserializing EnergyLoadForecastProvider from DB row: {row}. 
Error: {e}") return None - def add(self, home_forecast_provider: HomeForecastProvider) -> None: - """Add a new home forecast provider to the repository.""" - self.logger.debug(f"Adding forecast provider {home_forecast_provider.id} to SQLite repository.") + def add(self, energy_load_forecast_provider: EnergyLoadForecastProvider) -> None: + self.logger.debug(f"Adding forecast provider {energy_load_forecast_provider.id} to SQLite repository.") sql = """ - INSERT INTO home_forecast_providers (id, name, adapter_type, config, external_service_id) + INSERT INTO energy_load_forecast_providers (id, name, adapter_type, config, external_service_id) VALUES (?, ?, ?, ?, ?); """ conn = self._db.get_connection() try: - # Serialize config to JSON for storage config_json: str = "" - if home_forecast_provider.config: - config_json = json.dumps(home_forecast_provider.config.to_dict()) + if energy_load_forecast_provider.config: + config_json = json.dumps(energy_load_forecast_provider.config.to_dict()) with conn: cursor = conn.cursor() cursor.execute( sql, ( - home_forecast_provider.id, - home_forecast_provider.name, - home_forecast_provider.adapter_type.value, + energy_load_forecast_provider.id, + energy_load_forecast_provider.name, + energy_load_forecast_provider.adapter_type.value, config_json, - home_forecast_provider.external_service_id, + energy_load_forecast_provider.external_service_id, ), ) except sqlite3.IntegrityError as e: - self.logger.error(f"Integrity error adding home forecast provider {home_forecast_provider.id}: {e}") - # Could mean that the ID already exists - raise HomeForecastProviderAlreadyExistsError( - f"Home forecast provider with ID {home_forecast_provider.id} " + self.logger.error( + f"Integrity error adding energy load forecast provider {energy_load_forecast_provider.id}: {e}" + ) + raise EnergyLoadForecastProviderAlreadyExistsError( + f"Energy load forecast provider with ID {energy_load_forecast_provider.id} " f"already exists or constraint violation: {e}" 
) from e except sqlite3.Error as e: - self.logger.error(f"SQLite error adding home forecast provider {home_forecast_provider.id}: {e}") - raise HomeForecastProviderError(f"DB error adding home forecast provider: {e}") from e + self.logger.error( + f"SQLite error adding energy load forecast provider {energy_load_forecast_provider.id}: {e}" + ) + raise EnergyLoadForecastProviderError(f"DB error adding energy load forecast provider: {e}") from e finally: if conn: conn.close() - def get_by_id(self, home_forecast_provider_id: EntityId) -> Optional[HomeForecastProvider]: - """Retrieve an home forecast provider by its ID.""" - self.logger.debug(f"Retrieving home forecast provider {home_forecast_provider_id} from SQLite repository.") - sql = "SELECT * FROM home_forecast_providers WHERE id = ?;" + def get_by_id(self, energy_load_forecast_provider_id: EntityId) -> Optional[EnergyLoadForecastProvider]: + sql = "SELECT * FROM energy_load_forecast_providers WHERE id = ?;" conn = self._db.get_connection() try: cursor = conn.cursor() - cursor.execute(sql, (home_forecast_provider_id,)) + cursor.execute(sql, (energy_load_forecast_provider_id,)) row = cursor.fetchone() - return self._row_to_home_forecast_provider(row) + return self._row_to_energy_load_forecast_provider(row) except sqlite3.Error as e: - self.logger.error(f"SQLite error retrieving home forecast provider {home_forecast_provider_id}: {e}") - raise HomeForecastProviderNotFoundError(f"DB error retrieving home forecast provider: {e}") from e + self.logger.error( + f"SQLite error retrieving energy load forecast provider {energy_load_forecast_provider_id}: {e}" + ) + raise EnergyLoadForecastProviderNotFoundError( + f"DB error retrieving energy load forecast provider: {e}" + ) from e finally: if conn: conn.close() - def get_all(self) -> List[HomeForecastProvider]: - """Retrieve all home forecast providers from the repository.""" - self.logger.debug("Retrieving all home forecast providers from SQLite repository.") - sql = 
"SELECT * FROM home_forecast_providers;" + def get_all(self) -> List[EnergyLoadForecastProvider]: + sql = "SELECT * FROM energy_load_forecast_providers;" conn = self._db.get_connection() try: cursor = conn.cursor() cursor.execute(sql) rows = cursor.fetchall() - home_forecast_providers = [] + energy_load_forecast_providers = [] for row in rows: - home_forecast_provider = self._row_to_home_forecast_provider(row) - if home_forecast_provider: - home_forecast_providers.append(home_forecast_provider) + provider = self._row_to_energy_load_forecast_provider(row) + if provider: + energy_load_forecast_providers.append(provider) except sqlite3.Error as e: - self.logger.error(f"SQLite error retrieving all home forecast providers: {e}") + self.logger.error(f"SQLite error retrieving all energy load forecast providers: {e}") return [] finally: if conn: conn.close() - return home_forecast_providers + return energy_load_forecast_providers - def update(self, home_forecast_provider: HomeForecastProvider) -> None: - """Update an existing home forecast provider in the repository.""" - self.logger.debug(f"Updating home forecast provider {home_forecast_provider.id} in SQLite repository.") + def update(self, energy_load_forecast_provider: EnergyLoadForecastProvider) -> None: sql = """ - UPDATE home_forecast_providers + UPDATE energy_load_forecast_providers SET name = ?, adapter_type = ?, config = ?, external_service_id = ? 
WHERE id = ?; """ conn = self._db.get_connection() try: - # Serialize config to JSON for storage - config_json = json.dumps(home_forecast_provider.config) + config_json: str = "" + if energy_load_forecast_provider.config: + config_json = json.dumps(energy_load_forecast_provider.config.to_dict()) with conn: cursor = conn.cursor() cursor.execute( sql, ( - home_forecast_provider.name, - home_forecast_provider.adapter_type.value, + energy_load_forecast_provider.name, + energy_load_forecast_provider.adapter_type.value, config_json, - home_forecast_provider.external_service_id, - home_forecast_provider.id, + energy_load_forecast_provider.external_service_id, + energy_load_forecast_provider.id, ), ) if cursor.rowcount == 0: - raise HomeForecastProviderNotFoundError( - f"Home Forecast Provider with ID {home_forecast_provider.id} not found." + raise EnergyLoadForecastProviderNotFoundError( + f"Energy Load Forecast Provider with ID {energy_load_forecast_provider.id} not found." ) except sqlite3.Error as e: - self.logger.error(f"SQLite error updating home forecast provider {home_forecast_provider.id}: {e}") - raise HomeForecastProviderError(f"DB error updating home forecast provider: {e}") from e + self.logger.error( + f"SQLite error updating energy load forecast provider {energy_load_forecast_provider.id}: {e}" + ) + raise EnergyLoadForecastProviderError(f"DB error updating energy load forecast provider: {e}") from e finally: if conn: conn.close() - def remove(self, home_forecast_provider_id: EntityId) -> None: - """Remove an home forecast provider from the repository.""" - self.logger.debug(f"Removing forecast provider {home_forecast_provider_id} from SQLite repository.") - sql = "DELETE FROM home_forecast_providers WHERE id = ?;" + def remove(self, energy_load_forecast_provider_id: EntityId) -> None: + sql = "DELETE FROM energy_load_forecast_providers WHERE id = ?;" conn = self._db.get_connection() try: with conn: cursor = conn.cursor() - cursor.execute(sql, (home_forecast_provider_id,)) + cursor.execute(sql, (energy_load_forecast_provider_id,)) if
cursor.rowcount == 0: self.logger.warning( - f"Attempted to remove non-existent home forecast provider {home_forecast_provider_id}." + f"Attempted to remove non-existent energy load forecast provider " + f"{energy_load_forecast_provider_id}." ) - # There is no need to raise an exception here, removing a - # non-existent is idempotent. except sqlite3.Error as e: - self.logger.error(f"SQLite error removing home forecast provider {home_forecast_provider_id}: {e}") - raise HomeForecastProviderError(f"DB error removing home forecast provider: {e}") from e + self.logger.error( + f"SQLite error removing energy load forecast provider {energy_load_forecast_provider_id}: {e}" + ) + raise EnergyLoadForecastProviderError(f"DB error removing energy load forecast provider: {e}") from e finally: if conn: conn.close() - def get_by_external_service_id(self, external_service_id: EntityId) -> List[HomeForecastProvider]: - """Retrieve all home forecast providers linked to a specific external service.""" - self.logger.debug( - "Retrieving home forecast providers linked to external service " - f"{external_service_id} from SQLite repository." 
- ) - sql = "SELECT * FROM home_forecast_providers WHERE external_service_id = ?;" + def get_by_external_service_id(self, external_service_id: EntityId) -> List[EnergyLoadForecastProvider]: + sql = "SELECT * FROM energy_load_forecast_providers WHERE external_service_id = ?;" conn = self._db.get_connection() try: cursor = conn.cursor() cursor.execute(sql, (external_service_id,)) rows = cursor.fetchall() - home_forecast_providers = [] + energy_load_forecast_providers = [] for row in rows: - home_forecast_provider = self._row_to_home_forecast_provider(row) - if home_forecast_provider: - home_forecast_providers.append(home_forecast_provider) - return home_forecast_providers + provider = self._row_to_energy_load_forecast_provider(row) + if provider: + energy_load_forecast_providers.append(provider) + return energy_load_forecast_providers except sqlite3.Error as e: self.logger.error( - f"SQLite error retrieving home forecast providers by external service ID {external_service_id}: {e}" + f"SQLite error retrieving energy load forecast providers by external service ID " + f"{external_service_id}: {e}" ) return [] finally: @@ -446,85 +537,67 @@ def get_by_external_service_id(self, external_service_id: EntityId) -> List[Home conn.close() -# SQLAlchemy implementation - - -class SqlAlchemyHomeForecastProviderRepository(HomeForecastProviderRepository): - """SQLAlchemy implementation of HomeForecastProviderRepository. +# --- SQLAlchemy implementations --- - This repository works directly with the imperatively mapped HomeForecastProvider domain entity. - The config field is automatically converted between HomeForecastProviderConfig objects and JSON - strings by the custom TypeDecorator and event listener defined in tables.py. 
- Args: - db: BaseSQLAlchemyRepository instance for database operations - """ +class SqlAlchemyEnergyLoadForecastProviderRepository(EnergyLoadForecastProviderRepository): + """SQLAlchemy implementation of EnergyLoadForecastProviderRepository.""" def __init__(self, db: BaseSQLAlchemyRepository): - """Initialize repository with database instance. - - Args: - db: BaseSQLAlchemyRepository instance - """ self._db = db self.logger = db.logger - def add(self, home_forecast_provider: HomeForecastProvider) -> None: - """Add a home forecast provider to the repository.""" + def add(self, energy_load_forecast_provider: EnergyLoadForecastProvider) -> None: session = self._db.get_session() try: - session.add(home_forecast_provider) + session.add(energy_load_forecast_provider) session.commit() finally: session.close() - def get_by_id(self, home_forecast_provider_id: EntityId) -> Optional[HomeForecastProvider]: - """Get a home forecast provider by ID.""" + def get_by_id(self, energy_load_forecast_provider_id: EntityId) -> Optional[EnergyLoadForecastProvider]: session = self._db.get_session() try: - stmt = select(HomeForecastProvider).where( - home_forecast_providers_table.c.id == str(home_forecast_provider_id) + stmt = select(EnergyLoadForecastProvider).where( + energy_load_forecast_providers_table.c.id == str(energy_load_forecast_provider_id) ) entity = session.execute(stmt).scalar_one_or_none() return entity finally: session.close() - def get_all(self) -> List[HomeForecastProvider]: - """Get all home forecast providers.""" + def get_all(self) -> List[EnergyLoadForecastProvider]: session = self._db.get_session() try: - stmt = select(HomeForecastProvider) + stmt = select(EnergyLoadForecastProvider) entities = session.execute(stmt).scalars().all() return list(entities) finally: session.close() - def update(self, home_forecast_provider: HomeForecastProvider) -> None: - """Update a home forecast provider.""" + def update(self, energy_load_forecast_provider: 
EnergyLoadForecastProvider) -> None: session = self._db.get_session() try: - stmt = select(HomeForecastProvider).where( - home_forecast_providers_table.c.id == str(home_forecast_provider.id) + stmt = select(EnergyLoadForecastProvider).where( + energy_load_forecast_providers_table.c.id == str(energy_load_forecast_provider.id) ) existing_entity = session.execute(stmt).scalar_one_or_none() if existing_entity: - existing_entity.name = home_forecast_provider.name - existing_entity.adapter_type = home_forecast_provider.adapter_type - existing_entity.config = home_forecast_provider.config - existing_entity.external_service_id = home_forecast_provider.external_service_id + existing_entity.name = energy_load_forecast_provider.name + existing_entity.adapter_type = energy_load_forecast_provider.adapter_type + existing_entity.config = energy_load_forecast_provider.config + existing_entity.external_service_id = energy_load_forecast_provider.external_service_id session.commit() finally: session.close() - def remove(self, home_forecast_provider_id: EntityId) -> None: - """Remove a home forecast provider by ID.""" + def remove(self, energy_load_forecast_provider_id: EntityId) -> None: session = self._db.get_session() try: - stmt = select(HomeForecastProvider).where( - home_forecast_providers_table.c.id == str(home_forecast_provider_id) + stmt = select(EnergyLoadForecastProvider).where( + energy_load_forecast_providers_table.c.id == str(energy_load_forecast_provider_id) ) entity = session.execute(stmt).scalar_one_or_none() @@ -534,12 +607,11 @@ def remove(self, home_forecast_provider_id: EntityId) -> None: finally: session.close() - def get_by_external_service_id(self, external_service_id: EntityId) -> List[HomeForecastProvider]: - """Get home forecast providers by external service ID.""" + def get_by_external_service_id(self, external_service_id: EntityId) -> List[EnergyLoadForecastProvider]: session = self._db.get_session() try: - stmt = select(HomeForecastProvider).where( - 
home_forecast_providers_table.c.external_service_id == str(external_service_id) + stmt = select(EnergyLoadForecastProvider).where( + energy_load_forecast_providers_table.c.external_service_id == str(external_service_id) ) entities = session.execute(stmt).scalars().all() return list(entities) @@ -548,57 +620,1069 @@ def get_by_external_service_id(self, external_service_id: EntityId) -> List[Home class SqlAlchemyHomeLoadsProfileRepository(HomeLoadsProfileRepository): - """SQLAlchemy implementation of the HomeLoadsProfileRepository. - - This repository works directly with the imperatively mapped HomeLoadsProfile aggregate root. - The devices field is automatically converted between Dict[EntityId, LoadDevice] and JSON - by the custom TypeDecorator and event listener defined in tables.py. - - Args: - db: BaseSQLAlchemyRepository instance for database operations - """ - - # fixed UUID for the default profile - _DEFAULT_PROFILE_UUID = uuid.UUID("00000000-0000-0000-0000-000000000001") + """SQLAlchemy implementation of the HomeLoadsProfileRepository.""" def __init__(self, db: BaseSQLAlchemyRepository): - """Initialize repository with database instance. 
- - Args: - db: BaseSQLAlchemyRepository instance - """ self._db = db self.logger = db.logger - def get_profile(self) -> Optional[HomeLoadsProfile]: - """Get the home load profile from the database.""" + def add(self, profile: HomeLoadsProfile) -> None: + session = self._db.get_session() + try: + session.add(profile) + session.commit() + finally: + session.close() + + def get_by_id(self, profile_id: EntityId) -> Optional[HomeLoadsProfile]: session = self._db.get_session() try: - stmt = select(HomeLoadsProfile).where(home_profiles_table.c.id == str(self._DEFAULT_PROFILE_UUID)) + stmt = select(HomeLoadsProfile).where(home_profiles_table.c.id == str(profile_id)) entity = session.execute(stmt).scalar_one_or_none() return entity finally: session.close() - def save_profile(self, profile: HomeLoadsProfile) -> None: - """Save the home load profile to the database.""" + def get_all(self) -> List[HomeLoadsProfile]: + session = self._db.get_session() + try: + stmt = select(HomeLoadsProfile) + entities = session.execute(stmt).scalars().all() + return list(entities) + finally: + session.close() + + def update(self, profile: HomeLoadsProfile) -> None: session = self._db.get_session() try: - # Check if profile already exists - stmt = select(HomeLoadsProfile).where(home_profiles_table.c.id == str(self._DEFAULT_PROFILE_UUID)) + stmt = select(HomeLoadsProfile).where(home_profiles_table.c.id == str(profile.id)) existing_entity = session.execute(stmt).scalar_one_or_none() if existing_entity: - # Update existing profile existing_entity.name = profile.name existing_entity.devices = profile.devices session.commit() + finally: + session.close() + + def remove(self, profile_id: EntityId) -> None: + session = self._db.get_session() + try: + stmt = select(HomeLoadsProfile).where(home_profiles_table.c.id == str(profile_id)) + entity = session.execute(stmt).scalar_one_or_none() + if entity: + session.delete(entity) + session.commit() + finally: + session.close() + + def 
get_by_energy_load_forecast_provider_id(self, provider_id: EntityId) -> List[HomeLoadsProfile]:
+        # devices is a Dict[EntityId, LoadDevice]; iterate values, not keys
+        return [
+            profile
+            for profile in self.get_all()
+            if any(device.energy_load_forecast_provider_id == provider_id for device in profile.devices.values())
+        ]
+
+
+# --- EnergyLoadHistory (per-device power-point time series) Repositories ---
+
+
+class InMemoryEnergyLoadHistoryRepository(EnergyLoadHistoryRepository):
+    """In-memory power-point store, indexed by device and kept sorted by timestamp."""
+
+    def __init__(self) -> None:
+        self._store: Dict[EntityId, List[HomeLoadPowerPoint]] = {}
+
+    def _sorted_points(self, device_id: EntityId) -> List[HomeLoadPowerPoint]:
+        bucket = self._store.setdefault(device_id, [])
+        bucket.sort(key=lambda p: p.timestamp)
+        return bucket
+
+    def add_power_point(self, device_id: EntityId, power_point: HomeLoadPowerPoint) -> None:
+        self._store.setdefault(device_id, []).append(power_point)
+
+    def add_power_points(self, device_id: EntityId, power_points: List[HomeLoadPowerPoint]) -> None:
+        if not power_points:
+            return
+        self._store.setdefault(device_id, []).extend(power_points)
+
+    def get_power_points(self, device_id: EntityId, start: Timestamp, end: Timestamp) -> List[HomeLoadPowerPoint]:
+        return [p for p in self._sorted_points(device_id) if start <= p.timestamp < end]
+
+    def get_latest_timestamp(self, device_id: EntityId) -> Optional[Timestamp]:
+        points = self._store.get(device_id)
+        if not points:
+            return None
+        return max(p.timestamp for p in points)
+
+    def purge_before(self, device_id: EntityId, timestamp: Timestamp) -> int:
+        bucket = self._store.get(device_id)
+        if not bucket:
+            return 0
+        kept = [p for p in bucket if p.timestamp >= timestamp]
+        removed = len(bucket) - len(kept)
+        self._store[device_id] = kept
+        return removed
+
+    def remove_power_points_by_time_range(self, device_id: EntityId, start: Timestamp, end: Timestamp) -> None:
+        bucket = self._store.get(device_id)
+        if not bucket:
+            return
+        self._store[device_id] = [p
for p in bucket if not (start <= p.timestamp < end)] + + def clear_device_history(self, device_id: EntityId) -> int: + bucket = self._store.pop(device_id, []) + return len(bucket) + + +class SqliteEnergyLoadHistoryRepository(EnergyLoadHistoryRepository): + """SQLite implementation of the device-scoped power-point time series. + + Uses a composite primary key (device_id, timestamp) so re-ingesting the + same window is idempotent (``INSERT OR IGNORE``). Retention and range + queries lean on the implicit PK index for O(log n) behavior. + """ + + def __init__(self, db: BaseSqliteRepository): + self._db = db + self.logger = db.logger + self._create_tables() + + def _create_tables(self) -> None: + self.logger.debug(f"Ensuring SQLite tables exist for Energy Load History Repository in {self._db.db_path}...") + sql = """ + CREATE TABLE IF NOT EXISTS home_load_power_points ( + device_id TEXT NOT NULL, + timestamp TIMESTAMP NOT NULL, + power REAL NOT NULL, + PRIMARY KEY (device_id, timestamp) + ); + """ + conn = self._db.get_connection() + try: + with conn: + conn.execute(sql) + except sqlite3.Error as e: + self.logger.error(f"Error creating SQLite tables: {e}") + raise ConfigurationError(f"DB error creating tables: {e}") from e + finally: + if conn: + conn.close() + + def add_power_point(self, device_id: EntityId, power_point: HomeLoadPowerPoint) -> None: + self.add_power_points(device_id, [power_point]) + + def add_power_points(self, device_id: EntityId, power_points: List[HomeLoadPowerPoint]) -> None: + if not power_points: + return + sql = """ + INSERT OR IGNORE INTO home_load_power_points (device_id, timestamp, power) + VALUES (?, ?, ?); + """ + conn = self._db.get_connection() + try: + rows = [(str(device_id), p.timestamp, float(p.power)) for p in power_points] + with conn: + conn.executemany(sql, rows) + except sqlite3.Error as e: + self.logger.error(f"SQLite error inserting power points for device {device_id}: {e}") + raise ConfigurationError(f"DB error inserting 
power points: {e}") from e + finally: + if conn: + conn.close() + + def get_power_points(self, device_id: EntityId, start: Timestamp, end: Timestamp) -> List[HomeLoadPowerPoint]: + sql = """ + SELECT timestamp, power + FROM home_load_power_points + WHERE device_id = ? AND timestamp >= ? AND timestamp < ? + ORDER BY timestamp ASC; + """ + conn = self._db.get_connection() + try: + cursor = conn.cursor() + cursor.execute(sql, (str(device_id), start, end)) + rows = cursor.fetchall() + return [ + HomeLoadPowerPoint(timestamp=Timestamp(row["timestamp"]), power=Watts(row["power"])) for row in rows + ] + except sqlite3.Error as e: + self.logger.error(f"SQLite error reading power points for device {device_id}: {e}") + return [] + finally: + if conn: + conn.close() + + def get_latest_timestamp(self, device_id: EntityId) -> Optional[Timestamp]: + sql = "SELECT MAX(timestamp) AS ts FROM home_load_power_points WHERE device_id = ?;" + conn = self._db.get_connection() + try: + cursor = conn.cursor() + cursor.execute(sql, (str(device_id),)) + row = cursor.fetchone() + if not row or row["ts"] is None: + return None + return Timestamp(row["ts"]) + except sqlite3.Error as e: + self.logger.error(f"SQLite error getting latest timestamp for device {device_id}: {e}") + return None + finally: + if conn: + conn.close() + + def purge_before(self, device_id: EntityId, timestamp: Timestamp) -> int: + sql = "DELETE FROM home_load_power_points WHERE device_id = ? AND timestamp < ?;" + conn = self._db.get_connection() + try: + with conn: + cursor = conn.cursor() + cursor.execute(sql, (str(device_id), timestamp)) + return cursor.rowcount or 0 + except sqlite3.Error as e: + self.logger.error(f"SQLite error purging power points for device {device_id}: {e}") + return 0 + finally: + if conn: + conn.close() + + def remove_power_points_by_time_range(self, device_id: EntityId, start: Timestamp, end: Timestamp) -> None: + sql = "DELETE FROM home_load_power_points WHERE device_id = ? AND timestamp >= ? 
AND timestamp < ?;" + conn = self._db.get_connection() + try: + with conn: + conn.execute(sql, (str(device_id), start, end)) + except sqlite3.Error as e: + self.logger.error(f"SQLite error removing range for device {device_id}: {e}") + finally: + if conn: + conn.close() + + def clear_device_history(self, device_id: EntityId) -> int: + sql = "DELETE FROM home_load_power_points WHERE device_id = ?;" + conn = self._db.get_connection() + try: + with conn: + cursor = conn.cursor() + cursor.execute(sql, (str(device_id),)) + return cursor.rowcount or 0 + except sqlite3.Error as e: + self.logger.error(f"SQLite error clearing history for device {device_id}: {e}") + return 0 + finally: + if conn: + conn.close() + + +class SqlAlchemyEnergyLoadHistoryRepository(EnergyLoadHistoryRepository): + """SQLAlchemy Core implementation of the device-scoped power-point store. + + Core (not imperative mapping) is intentional: ``HomeLoadPowerPoint`` is a + Value Object, not an Entity — we serialize/deserialize manually and avoid + polluting the domain with ORM state. 
+ """ + + def __init__(self, db: BaseSQLAlchemyRepository): + self._db = db + self.logger = db.logger + + def add_power_point(self, device_id: EntityId, power_point: HomeLoadPowerPoint) -> None: + self.add_power_points(device_id, [power_point]) + + def add_power_points(self, device_id: EntityId, power_points: List[HomeLoadPowerPoint]) -> None: + if not power_points: + return + rows = [{"device_id": str(device_id), "timestamp": p.timestamp, "power": float(p.power)} for p in power_points] + session = self._db.get_session() + try: + dialect_name = session.bind.dialect.name if session.bind else "" + if dialect_name == "sqlite": + from sqlalchemy.dialects.sqlite import insert as sqlite_insert + + stmt = sqlite_insert(home_load_power_points_table).on_conflict_do_nothing( + index_elements=["device_id", "timestamp"] + ) + elif dialect_name == "postgresql": + from sqlalchemy.dialects.postgresql import insert as pg_insert + + stmt = pg_insert(home_load_power_points_table).on_conflict_do_nothing( + index_elements=["device_id", "timestamp"] + ) else: - # Create new profile with fixed UUID - new_profile = HomeLoadsProfile( - id=EntityId(self._DEFAULT_PROFILE_UUID), name=profile.name, devices=profile.devices + stmt = insert(home_load_power_points_table) + session.execute(stmt, rows) + session.commit() + finally: + session.close() + + def get_power_points(self, device_id: EntityId, start: Timestamp, end: Timestamp) -> List[HomeLoadPowerPoint]: + session = self._db.get_session() + try: + stmt = ( + select( + home_load_power_points_table.c.timestamp, + home_load_power_points_table.c.power, + ) + .where(home_load_power_points_table.c.device_id == str(device_id)) + .where(home_load_power_points_table.c.timestamp >= start) + .where(home_load_power_points_table.c.timestamp < end) + .order_by(home_load_power_points_table.c.timestamp.asc()) + ) + rows = session.execute(stmt).all() + return [ + HomeLoadPowerPoint( + timestamp=Timestamp(ts if ts.tzinfo else 
ts.replace(tzinfo=timezone.utc)), + power=Watts(power), ) - session.add(new_profile) + for ts, power in rows + ] + finally: + session.close() + + def get_latest_timestamp(self, device_id: EntityId) -> Optional[Timestamp]: + session = self._db.get_session() + try: + stmt = select(func.max(home_load_power_points_table.c.timestamp)).where( + home_load_power_points_table.c.device_id == str(device_id) + ) + latest = session.execute(stmt).scalar_one_or_none() + if latest is None: + return None + if isinstance(latest, datetime) and latest.tzinfo is None: + latest = latest.replace(tzinfo=timezone.utc) + return Timestamp(latest) + finally: + session.close() + + def purge_before(self, device_id: EntityId, timestamp: Timestamp) -> int: + session = self._db.get_session() + try: + stmt = delete(home_load_power_points_table).where( + home_load_power_points_table.c.device_id == str(device_id), + home_load_power_points_table.c.timestamp < timestamp, + ) + result = session.execute(stmt) + session.commit() + return result.rowcount or 0 + finally: + session.close() + + def remove_power_points_by_time_range(self, device_id: EntityId, start: Timestamp, end: Timestamp) -> None: + session = self._db.get_session() + try: + stmt = delete(home_load_power_points_table).where( + home_load_power_points_table.c.device_id == str(device_id), + home_load_power_points_table.c.timestamp >= start, + home_load_power_points_table.c.timestamp < end, + ) + session.execute(stmt) + session.commit() + finally: + session.close() + + def clear_device_history(self, device_id: EntityId) -> int: + session = self._db.get_session() + try: + stmt = delete(home_load_power_points_table).where( + home_load_power_points_table.c.device_id == str(device_id), + ) + result = session.execute(stmt) + session.commit() + return result.rowcount or 0 + finally: + session.close() + + +# --- EnergyLoadHistoryProvider Repositories --- + + +class InMemoryEnergyLoadHistoryProviderRepository(EnergyLoadHistoryProviderRepository): + 
"""In-memory implementation of EnergyLoadHistoryProviderRepository.""" + + def __init__(self): + self._providers: List[EnergyLoadHistoryProvider] = [] + + def add(self, energy_load_history_provider: EnergyLoadHistoryProvider) -> None: + self._providers.append(energy_load_history_provider) + + def get_by_id(self, energy_load_history_provider_id: EntityId) -> Optional[EnergyLoadHistoryProvider]: + for provider in self._providers: + if provider.id == energy_load_history_provider_id: + return provider + return None + + def get_all(self) -> List[EnergyLoadHistoryProvider]: + return self._providers + + def update(self, energy_load_history_provider: EnergyLoadHistoryProvider) -> None: + for i, existing in enumerate(self._providers): + if existing.id == energy_load_history_provider.id: + self._providers[i] = energy_load_history_provider + return + + def remove(self, energy_load_history_provider_id: EntityId) -> None: + self._providers = [p for p in self._providers if p.id != energy_load_history_provider_id] + + def get_by_external_service_id(self, external_service_id: EntityId) -> List[EnergyLoadHistoryProvider]: + if not external_service_id: + return [] + return [p for p in self._providers if p.external_service_id == external_service_id] + + +class SqliteEnergyLoadHistoryProviderRepository(EnergyLoadHistoryProviderRepository): + """SQLite implementation of EnergyLoadHistoryProviderRepository.""" + + def __init__(self, db: BaseSqliteRepository): + self._db = db + self.logger = db.logger + self._create_tables() + + def _create_tables(self): + self.logger.debug( + f"Ensuring SQLite tables exist for Energy Load History Provider Repository in {self._db.db_path}..." 
+ ) + sql_statements = [ + """ + CREATE TABLE IF NOT EXISTS energy_load_history_providers ( + id TEXT PRIMARY KEY, + name TEXT NOT NULL, + adapter_type TEXT NOT NULL, + config TEXT, + external_service_id TEXT + ); + """ + ] + conn = self._db.get_connection() + try: + with conn: + cursor = conn.cursor() + for statement in sql_statements: + cursor.execute(statement) + except sqlite3.Error as e: + self.logger.error(f"Error creating SQLite tables: {e}") + raise ConfigurationError(f"DB error creating tables: {e}") from e + finally: + if conn: + conn.close() + + def _deserialize_config( + self, adapter_type: EnergyLoadHistoryProviderAdapter, config_json: str + ) -> Optional[EnergyLoadHistoryProviderConfig]: + if not config_json: + return None + data: dict = json.loads(config_json) + + if adapter_type not in ENERGY_LOAD_HISTORY_PROVIDER_CONFIG_TYPE_MAP: + raise EnergyLoadHistoryProviderNotFoundError( + f"Error reading EnergyLoadHistoryProvider configuration. Invalid type '{adapter_type}'" + ) + + config_class: Optional[type[EnergyLoadHistoryProviderConfig]] = ( + ENERGY_LOAD_HISTORY_PROVIDER_CONFIG_TYPE_MAP.get(adapter_type) + ) + if not config_class: + return None + + config_instance = config_class.from_dict(data) + if not isinstance(config_instance, EnergyLoadHistoryProviderConfig): + raise EnergyLoadHistoryProviderConfigurationError( + f"Deserialized config is not of type EnergyLoadHistoryProviderConfig " + f"for adapter type {adapter_type}." 
+ ) + return config_instance + + def _row_to_provider(self, row: sqlite3.Row) -> Optional[EnergyLoadHistoryProvider]: + if not row: + return None + try: + provider_type = EnergyLoadHistoryProviderAdapter(row["adapter_type"]) + config = self._deserialize_config(provider_type, row["config"]) + return EnergyLoadHistoryProvider( + id=EntityId(row["id"]), + name=row["name"], + adapter_type=provider_type, + config=config, + external_service_id=(EntityId(row["external_service_id"]) if row["external_service_id"] else None), + ) + except (ValueError, KeyError) as e: + self.logger.error(f"Error deserializing EnergyLoadHistoryProvider from DB row: {row}. Error: {e}") + return None + + def add(self, energy_load_history_provider: EnergyLoadHistoryProvider) -> None: + self.logger.debug(f"Adding history provider {energy_load_history_provider.id} to SQLite repository.") + sql = """ + INSERT INTO energy_load_history_providers (id, name, adapter_type, config, external_service_id) + VALUES (?, ?, ?, ?, ?); + """ + conn = self._db.get_connection() + try: + config_json: str = "" + if energy_load_history_provider.config: + config_json = json.dumps(energy_load_history_provider.config.to_dict()) + with conn: + conn.execute( + sql, + ( + energy_load_history_provider.id, + energy_load_history_provider.name, + energy_load_history_provider.adapter_type.value, + config_json, + energy_load_history_provider.external_service_id, + ), + ) + except sqlite3.IntegrityError as e: + self.logger.error( + f"Integrity error adding energy load history provider {energy_load_history_provider.id}: {e}" + ) + raise EnergyLoadHistoryProviderAlreadyExistsError( + f"Energy load history provider with ID {energy_load_history_provider.id} " + f"already exists or constraint violation: {e}" + ) from e + except sqlite3.Error as e: + self.logger.error( + f"SQLite error adding energy load history provider {energy_load_history_provider.id}: {e}" + ) + raise EnergyLoadHistoryProviderError(f"DB error adding energy load 
history provider: {e}") from e + finally: + if conn: + conn.close() + + def get_by_id(self, energy_load_history_provider_id: EntityId) -> Optional[EnergyLoadHistoryProvider]: + sql = "SELECT * FROM energy_load_history_providers WHERE id = ?;" + conn = self._db.get_connection() + try: + cursor = conn.cursor() + cursor.execute(sql, (energy_load_history_provider_id,)) + row = cursor.fetchone() + return self._row_to_provider(row) + except sqlite3.Error as e: + self.logger.error( + f"SQLite error retrieving energy load history provider {energy_load_history_provider_id}: {e}" + ) + raise EnergyLoadHistoryProviderNotFoundError( + f"DB error retrieving energy load history provider: {e}" + ) from e + finally: + if conn: + conn.close() + + def get_all(self) -> List[EnergyLoadHistoryProvider]: + sql = "SELECT * FROM energy_load_history_providers;" + conn = self._db.get_connection() + try: + cursor = conn.cursor() + cursor.execute(sql) + rows = cursor.fetchall() + providers = [] + for row in rows: + provider = self._row_to_provider(row) + if provider: + providers.append(provider) + return providers + except sqlite3.Error as e: + self.logger.error(f"SQLite error retrieving all energy load history providers: {e}") + return [] + finally: + if conn: + conn.close() + + def update(self, energy_load_history_provider: EnergyLoadHistoryProvider) -> None: + sql = """ + UPDATE energy_load_history_providers + SET name = ?, adapter_type = ?, config = ?, external_service_id = ? 
+ WHERE id = ?; + """ + conn = self._db.get_connection() + try: + config_json = "" + if energy_load_history_provider.config: + config_json = json.dumps(energy_load_history_provider.config.to_dict()) + with conn: + cursor = conn.cursor() + cursor.execute( + sql, + ( + energy_load_history_provider.name, + energy_load_history_provider.adapter_type.value, + config_json, + energy_load_history_provider.external_service_id, + energy_load_history_provider.id, + ), + ) + if cursor.rowcount == 0: + raise EnergyLoadHistoryProviderNotFoundError( + f"Energy Load History Provider with ID {energy_load_history_provider.id} not found." + ) + except sqlite3.Error as e: + self.logger.error( + f"SQLite error updating energy load history provider {energy_load_history_provider.id}: {e}" + ) + raise EnergyLoadHistoryProviderError(f"DB error updating energy load history provider: {e}") from e + finally: + if conn: + conn.close() + + def remove(self, energy_load_history_provider_id: EntityId) -> None: + sql = "DELETE FROM energy_load_history_providers WHERE id = ?;" + conn = self._db.get_connection() + try: + with conn: + cursor = conn.cursor() + cursor.execute(sql, (energy_load_history_provider_id,)) + if cursor.rowcount == 0: + self.logger.warning( + f"Attempted to remove non-existent energy load history provider " + f"{energy_load_history_provider_id}." 
+ ) + except sqlite3.Error as e: + self.logger.error( + f"SQLite error removing energy load history provider {energy_load_history_provider_id}: {e}" + ) + raise EnergyLoadHistoryProviderError(f"DB error removing energy load history provider: {e}") from e + finally: + if conn: + conn.close() + + def get_by_external_service_id(self, external_service_id: EntityId) -> List[EnergyLoadHistoryProvider]: + sql = "SELECT * FROM energy_load_history_providers WHERE external_service_id = ?;" + conn = self._db.get_connection() + try: + cursor = conn.cursor() + cursor.execute(sql, (external_service_id,)) + rows = cursor.fetchall() + providers = [] + for row in rows: + provider = self._row_to_provider(row) + if provider: + providers.append(provider) + return providers + except sqlite3.Error as e: + self.logger.error( + f"SQLite error retrieving energy load history providers by external service ID " + f"{external_service_id}: {e}" + ) + return [] + finally: + if conn: + conn.close() + + +class SqlAlchemyEnergyLoadHistoryProviderRepository(EnergyLoadHistoryProviderRepository): + """SQLAlchemy implementation of EnergyLoadHistoryProviderRepository.""" + + def __init__(self, db: BaseSQLAlchemyRepository): + self._db = db + self.logger = db.logger + + def add(self, energy_load_history_provider: EnergyLoadHistoryProvider) -> None: + session = self._db.get_session() + try: + session.add(energy_load_history_provider) + session.commit() + finally: + session.close() + + def get_by_id(self, energy_load_history_provider_id: EntityId) -> Optional[EnergyLoadHistoryProvider]: + session = self._db.get_session() + try: + stmt = select(EnergyLoadHistoryProvider).where( + energy_load_history_providers_table.c.id == str(energy_load_history_provider_id) + ) + entity = session.execute(stmt).scalar_one_or_none() + return entity + finally: + session.close() + + def get_all(self) -> List[EnergyLoadHistoryProvider]: + session = self._db.get_session() + try: + stmt = select(EnergyLoadHistoryProvider) + 
entities = session.execute(stmt).scalars().all() + return list(entities) + finally: + session.close() + + def update(self, energy_load_history_provider: EnergyLoadHistoryProvider) -> None: + session = self._db.get_session() + try: + stmt = select(EnergyLoadHistoryProvider).where( + energy_load_history_providers_table.c.id == str(energy_load_history_provider.id) + ) + existing_entity = session.execute(stmt).scalar_one_or_none() + if existing_entity: + existing_entity.name = energy_load_history_provider.name + existing_entity.adapter_type = energy_load_history_provider.adapter_type + existing_entity.config = energy_load_history_provider.config + existing_entity.external_service_id = energy_load_history_provider.external_service_id + session.commit() + finally: + session.close() + + def remove(self, energy_load_history_provider_id: EntityId) -> None: + session = self._db.get_session() + try: + stmt = select(EnergyLoadHistoryProvider).where( + energy_load_history_providers_table.c.id == str(energy_load_history_provider_id) + ) + entity = session.execute(stmt).scalar_one_or_none() + if entity: + session.delete(entity) + session.commit() + finally: + session.close() + + def get_by_external_service_id(self, external_service_id: EntityId) -> List[EnergyLoadHistoryProvider]: + session = self._db.get_session() + try: + stmt = select(EnergyLoadHistoryProvider).where( + energy_load_history_providers_table.c.external_service_id == str(external_service_id) + ) + entities = session.execute(stmt).scalars().all() + return list(entities) + finally: + session.close() + + +# --- LoadConsumptionModel Repositories --- + + +class InMemoryLoadConsumptionModelRepository(LoadConsumptionModelRepository): + """In-memory implementation of LoadConsumptionModelRepository.""" + + def __init__(self) -> None: + self._models: Dict[str, LoadConsumptionModel] = {} + + def add(self, model: LoadConsumptionModel) -> None: + self._models[str(model.id)] = copy.deepcopy(model) + + def get_by_id(self, 
model_id: EntityId) -> Optional[LoadConsumptionModel]: + model = self._models.get(str(model_id)) + return copy.deepcopy(model) if model else None + + def get_active_model( + self, + adapter_type: EnergyLoadForecastProviderAdapter, + device_id: Optional[EntityId] = None, + ) -> Optional[LoadConsumptionModel]: + for model in self._models.values(): + if model.adapter_type == adapter_type and model.is_active: + if device_id is None and model.device_id is None: + return copy.deepcopy(model) + if device_id is not None and model.device_id is not None: + if str(model.device_id) == str(device_id): + return copy.deepcopy(model) + return None + + def get_all(self, device_id: Optional[EntityId] = None) -> List[LoadConsumptionModel]: + models = list(self._models.values()) + if device_id is not None: + models = [m for m in models if m.device_id is not None and str(m.device_id) == str(device_id)] + return [copy.deepcopy(m) for m in models] + + def update(self, model: LoadConsumptionModel) -> None: + key = str(model.id) + if key in self._models: + self._models[key] = copy.deepcopy(model) + + def remove(self, model_id: EntityId) -> None: + self._models.pop(str(model_id), None) + + +class SqliteLoadConsumptionModelRepository(LoadConsumptionModelRepository): + """SQLite implementation of LoadConsumptionModelRepository.""" + + def __init__(self, db: BaseSqliteRepository): + self._db = db + self.logger = db.logger + self._create_tables() + + def _create_tables(self) -> None: + self.logger.debug(f"Ensuring SQLite tables exist for LoadConsumptionModel Repository in {self._db.db_path}...") + sql = """ + CREATE TABLE IF NOT EXISTS load_consumption_models ( + id TEXT PRIMARY KEY, + device_id TEXT, + adapter_type TEXT NOT NULL, + trained_at TIMESTAMP, + mae REAL, + rmse REAL, + samples_used INTEGER NOT NULL DEFAULT 0, + is_active INTEGER NOT NULL DEFAULT 0, + model_bytes BLOB + ); + """ + idx_sql = """ + CREATE INDEX IF NOT EXISTS ix_load_consumption_models_active + ON 
load_consumption_models (adapter_type, device_id, is_active); + """ + conn = self._db.get_connection() + try: + with conn: + conn.execute(sql) + conn.execute(idx_sql) + except sqlite3.Error as e: + self.logger.error(f"Error creating SQLite tables for LoadConsumptionModel: {e}") + raise ConfigurationError(f"DB error creating tables: {e}") from e + finally: + if conn: + conn.close() + + def _row_to_model(self, row: sqlite3.Row) -> Optional[LoadConsumptionModel]: + if not row: + return None + try: + return LoadConsumptionModel( + id=EntityId(uuid.UUID(row["id"])), + device_id=EntityId(uuid.UUID(row["device_id"])) if row["device_id"] else None, + adapter_type=EnergyLoadForecastProviderAdapter(row["adapter_type"]), + trained_at=row["trained_at"], + mae=row["mae"], + rmse=row["rmse"], + samples_used=row["samples_used"], + is_active=bool(row["is_active"]), + model_bytes=row["model_bytes"], + ) + except (ValueError, KeyError) as e: + self.logger.error(f"Error deserializing LoadConsumptionModel from DB row: {e}") + return None + + def add(self, model: LoadConsumptionModel) -> None: + sql = """ + INSERT INTO load_consumption_models + (id, device_id, adapter_type, trained_at, mae, rmse, samples_used, is_active, model_bytes) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?); + """ + conn = self._db.get_connection() + try: + with conn: + conn.execute( + sql, + ( + str(model.id), + str(model.device_id) if model.device_id else None, + model.adapter_type.value + if isinstance(model.adapter_type, EnergyLoadForecastProviderAdapter) + else model.adapter_type, + model.trained_at, + model.mae, + model.rmse, + model.samples_used, + int(model.is_active), + model.model_bytes, + ), + ) + except sqlite3.Error as e: + self.logger.error(f"SQLite error adding LoadConsumptionModel {model.id}: {e}") + raise ConfigurationError(f"DB error adding LoadConsumptionModel: {e}") from e + finally: + if conn: + conn.close() + + def get_by_id(self, model_id: EntityId) -> Optional[LoadConsumptionModel]: + sql = "SELECT 
* FROM load_consumption_models WHERE id = ?;" + conn = self._db.get_connection() + try: + cursor = conn.cursor() + cursor.execute(sql, (str(model_id),)) + row = cursor.fetchone() + return self._row_to_model(row) + except sqlite3.Error as e: + self.logger.error(f"SQLite error retrieving LoadConsumptionModel {model_id}: {e}") + return None + finally: + if conn: + conn.close() + + def get_active_model( + self, + adapter_type: EnergyLoadForecastProviderAdapter, + device_id: Optional[EntityId] = None, + ) -> Optional[LoadConsumptionModel]: + adapter_val = ( + adapter_type.value if isinstance(adapter_type, EnergyLoadForecastProviderAdapter) else adapter_type + ) + if device_id is not None: + sql = """ + SELECT * FROM load_consumption_models + WHERE adapter_type = ? AND device_id = ? AND is_active = 1 + LIMIT 1; + """ + params = (adapter_val, str(device_id)) + else: + sql = """ + SELECT * FROM load_consumption_models + WHERE adapter_type = ? AND device_id IS NULL AND is_active = 1 + LIMIT 1; + """ + params = (adapter_val,) + conn = self._db.get_connection() + try: + cursor = conn.cursor() + cursor.execute(sql, params) + row = cursor.fetchone() + return self._row_to_model(row) + except sqlite3.Error as e: + self.logger.error(f"SQLite error retrieving active model for {adapter_type}: {e}") + return None + finally: + if conn: + conn.close() + + def get_all(self, device_id: Optional[EntityId] = None) -> List[LoadConsumptionModel]: + if device_id is not None: + sql = "SELECT * FROM load_consumption_models WHERE device_id = ? 
ORDER BY trained_at DESC;"
+            params = (str(device_id),)
+        else:
+            sql = "SELECT * FROM load_consumption_models ORDER BY trained_at DESC;"
+            params = ()
+        conn = self._db.get_connection()
+        try:
+            cursor = conn.cursor()
+            cursor.execute(sql, params)
+            rows = cursor.fetchall()
+            models: List[LoadConsumptionModel] = []
+            for row in rows:
+                model = self._row_to_model(row)
+                if model:
+                    models.append(model)
+            return models
+        except sqlite3.Error as e:
+            self.logger.error(f"SQLite error retrieving all LoadConsumptionModels: {e}")
+            return []
+        finally:
+            if conn:
+                conn.close()
+
+    def update(self, model: LoadConsumptionModel) -> None:
+        sql = """
+            UPDATE load_consumption_models
+            SET device_id = ?, adapter_type = ?, trained_at = ?, mae = ?, rmse = ?,
+                samples_used = ?, is_active = ?, model_bytes = ?
+            WHERE id = ?;
+        """
+        conn = self._db.get_connection()
+        try:
+            with conn:
+                conn.execute(
+                    sql,
+                    (
+                        str(model.device_id) if model.device_id else None,
+                        model.adapter_type.value
+                        if isinstance(model.adapter_type, EnergyLoadForecastProviderAdapter)
+                        else model.adapter_type,
+                        model.trained_at,
+                        model.mae,
+                        model.rmse,
+                        model.samples_used,
+                        int(model.is_active),
+                        model.model_bytes,
+                        str(model.id),
+                    ),
+                )
+        except sqlite3.Error as e:
+            self.logger.error(f"SQLite error updating LoadConsumptionModel {model.id}: {e}")
+            raise ConfigurationError(f"DB error updating LoadConsumptionModel: {e}") from e
+        finally:
+            if conn:
+                conn.close()
+
+    def remove(self, model_id: EntityId) -> None:
+        sql = "DELETE FROM load_consumption_models WHERE id = ?;"
+        conn = self._db.get_connection()
+        try:
+            with conn:
+                conn.execute(sql, (str(model_id),))
+        except sqlite3.Error as e:
+            self.logger.error(f"SQLite error removing LoadConsumptionModel {model_id}: {e}")
+            raise ConfigurationError(f"DB error removing LoadConsumptionModel: {e}") from e
+        finally:
+            if conn:
+                conn.close()
+
+
+class SqlAlchemyLoadConsumptionModelRepository(LoadConsumptionModelRepository):
+    """SQLAlchemy implementation of LoadConsumptionModelRepository."""
+
+    def __init__(self, db: BaseSQLAlchemyRepository):
+        self._db = db
+        self.logger = db.logger
+
+    def add(self, model: LoadConsumptionModel) -> None:
+        session = self._db.get_session()
+        try:
+            session.add(model)
+            session.commit()
+        finally:
+            session.close()
+
+    def get_by_id(self, model_id: EntityId) -> Optional[LoadConsumptionModel]:
+        session = self._db.get_session()
+        try:
+            stmt = select(LoadConsumptionModel).where(load_consumption_models_table.c.id == str(model_id))
+            entity = session.execute(stmt).scalar_one_or_none()
+            return entity
+        finally:
+            session.close()
+
+    def get_active_model(
+        self,
+        adapter_type: EnergyLoadForecastProviderAdapter,
+        device_id: Optional[EntityId] = None,
+    ) -> Optional[LoadConsumptionModel]:
+        session = self._db.get_session()
+        try:
+            adapter_val = (
+                adapter_type.value if isinstance(adapter_type, EnergyLoadForecastProviderAdapter) else adapter_type
+            )
+            stmt = (
+                select(LoadConsumptionModel)
+                .where(load_consumption_models_table.c.adapter_type == adapter_val)
+                .where(load_consumption_models_table.c.is_active == True)  # noqa: E712
+            )
+            if device_id is not None:
+                stmt = stmt.where(load_consumption_models_table.c.device_id == str(device_id))
+            else:
+                stmt = stmt.where(load_consumption_models_table.c.device_id.is_(None))
+            entity = session.execute(stmt).scalar_one_or_none()
+            return entity
+        finally:
+            session.close()
+
+    def get_all(self, device_id: Optional[EntityId] = None) -> List[LoadConsumptionModel]:
+        session = self._db.get_session()
+        try:
+            stmt = select(LoadConsumptionModel)
+            if device_id is not None:
+                stmt = stmt.where(load_consumption_models_table.c.device_id == str(device_id))
+            stmt = stmt.order_by(load_consumption_models_table.c.trained_at.desc())
+            return list(session.execute(stmt).scalars().all())
+        finally:
+            session.close()
+
+    def update(self, model: LoadConsumptionModel) -> None:
+        session = self._db.get_session()
+        try:
+            stmt = select(LoadConsumptionModel).where(load_consumption_models_table.c.id == str(model.id))
+            existing = session.execute(stmt).scalar_one_or_none()
+            if existing:
+                existing.device_id = model.device_id
+                existing.adapter_type = model.adapter_type
+                existing.trained_at = model.trained_at
+                existing.mae = model.mae
+                existing.rmse = model.rmse
+                existing.samples_used = model.samples_used
+                existing.is_active = model.is_active
+                existing.model_bytes = model.model_bytes
+                session.commit()
+        finally:
+            session.close()
+
+    def remove(self, model_id: EntityId) -> None:
+        session = self._db.get_session()
+        try:
+            stmt = select(LoadConsumptionModel).where(load_consumption_models_table.c.id == str(model_id))
+            entity = session.execute(stmt).scalar_one_or_none()
+            if entity:
+                session.delete(entity)
+                session.commit()
+        finally:
+            session.close()
diff --git a/edge_mining/adapters/domain/home_load/schemas.py b/edge_mining/adapters/domain/home_load/schemas.py
index a5a3c1f..441b210 100644
--- a/edge_mining/adapters/domain/home_load/schemas.py
+++ b/edge_mining/adapters/domain/home_load/schemas.py
@@ -2,59 +2,180 @@
 import uuid
 from datetime import datetime
-from typing import Dict, Optional, Union, cast
+from typing import Dict, List, Optional, Union, cast
 
-from pydantic import BaseModel, Field, field_serializer, field_validator
+from pydantic import BaseModel, Field, computed_field, field_serializer, field_validator
 
-from edge_mining.domain.common import EntityId, Timestamp, Watts
-from edge_mining.domain.home_load.common import HomeForecastProviderAdapter
-from edge_mining.domain.home_load.entities import HomeForecastProvider
-from edge_mining.domain.home_load.value_objects import ConsumptionForecast
-from edge_mining.shared.adapter_configs.home_load import HomeForecastProviderDummyConfig
-from edge_mining.shared.adapter_maps.home_load import HOME_FORECAST_PROVIDER_CONFIG_TYPE_MAP
-from edge_mining.shared.interfaces.config import HomeForecastProviderConfig
+from edge_mining.domain.common import EntityId, Timestamp, WattHours
+from edge_mining.domain.home_load.aggregate_roots import HomeLoadsProfile
+from edge_mining.domain.home_load.common import (
+    EnergyLoadForecastProviderAdapter,
+    EnergyLoadHistoryProviderAdapter,
+    LoadDeviceCategory,
+)
+from edge_mining.domain.home_load.entities import (
+    EnergyLoadForecastProvider,
+    EnergyLoadHistoryProvider,
+    LoadConsumptionModel,
+    LoadDevice,
+)
+from edge_mining.domain.home_load.value_objects import (
+    HomeLoadEnergyInterval,
+    HomeLoadPowerPoint,
+    HomeLoadsConsumption,
+    LoadDeviceConsumption,
+    LoadEnergyConsumption,
+)
+from edge_mining.shared.adapter_configs.home_load import (
+    EnergyLoadForecastProviderDummyConfig,
+    EnergyLoadForecastProviderNaiveLastHourConfig,
+    EnergyLoadForecastProviderNaivePersistenceConfig,
+    EnergyLoadForecastProviderSeasonalBaselineConfig,
+    EnergyLoadForecastProviderSkforecastConfig,
+    EnergyLoadForecastProviderStatsmodelsConfig,
+    EnergyLoadForecastProviderTypicalProfileConfig,
+    EnergyLoadForecastProviderXGBoostConfig,
+    EnergyLoadHistoryProviderHomeAssistantAPIConfig,
+)
+from edge_mining.shared.adapter_maps.home_load import (
+    ENERGY_LOAD_FORECAST_PROVIDER_CONFIG_TYPE_MAP,
+    ENERGY_LOAD_HISTORY_PROVIDER_CONFIG_TYPE_MAP,
+)
+from edge_mining.shared.interfaces.config import EnergyLoadForecastProviderConfig, EnergyLoadHistoryProviderConfig
 
 
-class ConsumptionForecastSchema(BaseModel):
-    """Schema for ConsumptionForecast value object."""
+class HomeLoadEnergyIntervalSchema(BaseModel):
+    """Schema for HomeLoadEnergyInterval value object."""
 
-    predicted_watts: Dict[str, float] = Field(
-        default_factory=dict, description="Predicted consumption watts by timestamp (ISO format keys)"
+    start: datetime = Field(..., description="Interval start timestamp")
+    end: datetime = Field(..., description="Interval end timestamp")
+    energy: Optional[float] = Field(default=None, description="Energy in watt-hours")
+    avg_power: Optional[float] = Field(default=None, description="Average power in watts")
+
+
+class LoadEnergyConsumptionSchema(BaseModel):
+    """Schema for LoadEnergyConsumption value object."""
+
+    timestamp: datetime = Field(..., description="When this consumption data was generated")
+    intervals: List[HomeLoadEnergyIntervalSchema] = Field(
+        default_factory=list, description="List of consumption intervals"
     )
-    generated_at: datetime = Field(default_factory=datetime.now, description="When the forecast was generated")
 
     @classmethod
-    def from_model(cls, forecast: ConsumptionForecast) -> "ConsumptionForecastSchema":
-        """Create schema from ConsumptionForecast value object."""
-        # Convert Timestamp keys to ISO string and Watts values to float
-        predicted_watts_dict = {ts.isoformat(): float(watts) for ts, watts in forecast.predicted_watts.items()}
+    def from_model(cls, consumption: LoadEnergyConsumption) -> "LoadEnergyConsumptionSchema":
+        """Create schema from domain model."""
+        intervals = [
+            HomeLoadEnergyIntervalSchema(
+                start=cast(datetime, interval.start),
+                end=cast(datetime, interval.end),
+                energy=float(interval.energy) if interval.energy is not None else None,
+                avg_power=float(interval.avg_power) if interval.avg_power is not None else None,
+            )
+            for interval in consumption.intervals
+        ]
+
         return cls(
-            predicted_watts=predicted_watts_dict,
-            generated_at=forecast.generated_at,
+            timestamp=cast(datetime, consumption.timestamp),
+            intervals=intervals,
         )
 
-    def to_model(self) -> ConsumptionForecast:
-        """Convert schema to ConsumptionForecast value object."""
-        # Convert ISO string keys back to Timestamp and float values to Watts
-        predicted_watts_dict = {
-            Timestamp(datetime.fromisoformat(ts)): Watts(watts) for ts, watts in self.predicted_watts.items()
-        }
-        return ConsumptionForecast(
-            predicted_watts=predicted_watts_dict,
-            generated_at=Timestamp(self.generated_at),
+    def to_model(self) -> LoadEnergyConsumption:
+        """Convert schema to domain model."""
+        intervals: List[HomeLoadEnergyInterval] = []
+        for interval_schema in self.intervals:
+            intervals.append(
+                HomeLoadEnergyInterval(
+                    start=Timestamp(interval_schema.start),
+                    end=Timestamp(interval_schema.end),
+                    energy=None if interval_schema.energy is None else WattHours(interval_schema.energy),
+                )
+            )
+
+        return LoadEnergyConsumption(
+            timestamp=Timestamp(self.timestamp),
+            intervals=intervals,
         )
 
 
-class HomeForecastProviderSchema(BaseModel):
-    """Schema for HomeForecastProvider entity with complete validation."""
+class LoadDeviceConsumptionSchema(BaseModel):
+    """Schema for LoadDeviceConsumption value object (device-bound history + forecast)."""
 
-    id: str = Field(..., description="Unique identifier for the home forecast provider")
-    name: str = Field(default="", description="Home forecast provider name")
-    adapter_type: HomeForecastProviderAdapter = Field(
-        default=HomeForecastProviderAdapter.DUMMY, description="Type of home forecast provider adapter"
+    device_id: str = Field(..., description="Device UUID")
+    device_name: str = Field(..., description="Device unique name within profile")
+    device_category: LoadDeviceCategory = Field(..., description="Device category")
+    history: LoadEnergyConsumptionSchema = Field(
+        default_factory=lambda: LoadEnergyConsumptionSchema(timestamp=datetime.now(), intervals=[]),
+        description="Measured consumption time series.",
+    )
+    forecast: LoadEnergyConsumptionSchema = Field(
+        default_factory=lambda: LoadEnergyConsumptionSchema(timestamp=datetime.now(), intervals=[]),
+        description="Predicted consumption time series.",
+    )
+
+    @classmethod
+    def from_model(cls, consumption: LoadDeviceConsumption) -> "LoadDeviceConsumptionSchema":
+        return cls(
+            device_id=str(consumption.device_id),
+            device_name=consumption.device_name,
+            device_category=consumption.device_category,
+            history=LoadEnergyConsumptionSchema.from_model(consumption.history),
+            forecast=LoadEnergyConsumptionSchema.from_model(consumption.forecast),
+        )
+
+    def to_model(self) -> LoadDeviceConsumption:
+        return LoadDeviceConsumption(
+            device_id=EntityId(uuid.UUID(self.device_id)),
+            device_name=self.device_name,
+            device_category=self.device_category,
+            history=self.history.to_model(),
+            forecast=self.forecast.to_model(),
+        )
+
+
+class HomeLoadsConsumptionSchema(BaseModel):
+    """Schema for HomeLoadsConsumption value object (unified household view)."""
+
+    per_device: List[LoadDeviceConsumptionSchema] = Field(default_factory=list)
+    total_history: LoadEnergyConsumptionSchema = Field(
+        default_factory=lambda: LoadEnergyConsumptionSchema(timestamp=datetime.now(), intervals=[]),
+        description="Aggregated household history.",
+    )
+    total_forecast: LoadEnergyConsumptionSchema = Field(
+        default_factory=lambda: LoadEnergyConsumptionSchema(timestamp=datetime.now(), intervals=[]),
+        description="Aggregated household forecast.",
+    )
+
+    @classmethod
+    def from_model(cls, consumption: HomeLoadsConsumption) -> "HomeLoadsConsumptionSchema":
+        return cls(
+            per_device=[LoadDeviceConsumptionSchema.from_model(d) for d in consumption.per_device],
+            total_history=LoadEnergyConsumptionSchema.from_model(consumption.total_history),
+            total_forecast=LoadEnergyConsumptionSchema.from_model(consumption.total_forecast),
+        )
+
+    def to_model(self) -> HomeLoadsConsumption:
+        return HomeLoadsConsumption(
+            per_device=[d.to_model() for d in self.per_device],
+            total_history=self.total_history.to_model(),
+            total_forecast=self.total_forecast.to_model(),
+        )
+
+
+class LoadDeviceSchema(BaseModel):
+    """Schema for LoadDevice entity with complete validation."""
+
+    id: str = Field(..., description="Unique identifier for the load device")
+    name: str = Field(default="", description="Load device name")
+    category: LoadDeviceCategory = Field(
+        default=LoadDeviceCategory.OCCASIONAL, description="Category of load device (e.g., controllable, continuous)"
+    )
+    enabled: bool = Field(default=True, description="Whether the load device is active in the system")
+    energy_load_forecast_provider_id: Optional[str] = Field(
+        default=None, description="ID of the energy load forecast provider associated with this load device"
+    )
+    energy_load_history_provider_id: Optional[str] = Field(
+        default=None, description="ID of the energy load history provider associated with this load device"
     )
-    config: dict = Field(default={}, description="Home forecast provider configuration")
-    external_service_id: Optional[str] = Field(default=None, description="ID of external service")
 
     @field_validator("id")
     @classmethod
@@ -62,27 +183,361 @@ def validate_id(cls, v: str) -> str:
         """Validate that id is a valid UUID string."""
         try:
             uuid.UUID(v)
-        except ValueError as exc:
-            raise ValueError("id must be a valid UUID string") from exc
+            return v
+        except ValueError as e:
+            raise ValueError(f"Invalid UUID format: {v}") from e
+
+    @field_validator("name")
+    @classmethod
+    def validate_name(cls, v: str) -> str:
+        """Validate device name."""
+        if not v.strip():
+            raise ValueError("Device name cannot be empty")
+        return v.strip()
+
+    @field_validator("energy_load_forecast_provider_id", "energy_load_history_provider_id")
+    @classmethod
+    def validate_provider_id(cls, v: Optional[str]) -> Optional[str]:
+        """Validate that provider ID is a valid UUID string if provided."""
+        if v is not None:
+            try:
+                uuid.UUID(v)
+            except ValueError as exc:
+                raise ValueError("Provider ID must be a valid UUID string") from exc
         return v
 
+    @classmethod
+    def from_model(cls, load_device: LoadDevice) -> "LoadDeviceSchema":
+        """Create schema from domain model."""
+        return cls(
+            id=str(load_device.id),
+            name=load_device.name,
+            category=load_device.category,
+            enabled=load_device.enabled,
+            energy_load_forecast_provider_id=(
+                str(load_device.energy_load_forecast_provider_id)
+                if load_device.energy_load_forecast_provider_id
+                else None
+            ),
+            energy_load_history_provider_id=(
+                str(load_device.energy_load_history_provider_id)
+                if load_device.energy_load_history_provider_id
+                else None
+            ),
+        )
+
+    @field_serializer("id")
+    def serialize_id(self, value: str) -> str:
+        """Serialize id field."""
+        return value
+
+    @field_serializer("energy_load_forecast_provider_id", "energy_load_history_provider_id")
+    def serialize_provider_id(self, value: Optional[str]) -> Optional[str]:
+        """Serialize provider ID field."""
+        return value
+
+    def to_model(self) -> LoadDevice:
+        """Convert schema to domain model."""
+        forecast_provider_id = (
+            EntityId(uuid.UUID(self.energy_load_forecast_provider_id))
+            if self.energy_load_forecast_provider_id
+            else None
+        )
+        history_provider_id = (
+            EntityId(uuid.UUID(self.energy_load_history_provider_id)) if self.energy_load_history_provider_id else None
+        )
+        return LoadDevice(
+            id=EntityId(uuid.UUID(self.id)),
+            name=self.name,
+            category=LoadDeviceCategory(self.category) if isinstance(self.category, str) else self.category,
+            enabled=self.enabled,
+            energy_load_forecast_provider_id=forecast_provider_id,
+            energy_load_history_provider_id=history_provider_id,
+        )
+
+    class Config:
+        """Pydantic configuration."""
+
+        use_enum_values = True
+
+
+class LoadDeviceCreateSchema(BaseModel):
+    """Schema for creating a new load device."""
+
+    name: str = Field(default="", description="Load device name")
+    category: LoadDeviceCategory = Field(default=LoadDeviceCategory.OCCASIONAL, description="Category of load device")
+    enabled: bool = Field(default=True, description="Whether the load device is active in the system")
+    energy_load_forecast_provider_id: Optional[str] = Field(
+        default=None, description="ID of the energy load forecast provider associated with this load device"
+    )
+    energy_load_history_provider_id: Optional[str] = Field(
+        default=None, description="ID of the energy load history provider associated with this load device"
+    )
+
     @field_validator("name")
     @classmethod
     def validate_name(cls, v: str) -> str:
-        """Validate home forecast provider name."""
-        v = v.strip()
-        if not v:
-            v = ""
+        """Validate device name."""
+        if not v.strip():
+            raise ValueError("Device name cannot be empty")
+        return v.strip()
+
+    @field_validator("energy_load_forecast_provider_id", "energy_load_history_provider_id")
+    @classmethod
+    def validate_provider_id(cls, v: Optional[str]) -> Optional[str]:
+        """Validate that provider ID is a valid UUID string if provided."""
+        if v is not None:
+            try:
+                uuid.UUID(v)
+            except ValueError as exc:
+                raise ValueError("Provider ID must be a valid UUID string") from exc
         return v
 
+    @field_serializer("energy_load_forecast_provider_id", "energy_load_history_provider_id")
+    def serialize_provider_id(self, value: Optional[str]) -> Optional[str]:
+        """Serialize provider ID field."""
+        return value
+
+    def to_model(self) -> LoadDevice:
+        """Convert schema to domain model."""
+        forecast_provider_id = (
+            EntityId(uuid.UUID(self.energy_load_forecast_provider_id))
+            if self.energy_load_forecast_provider_id
+            else None
+        )
+        history_provider_id = (
+            EntityId(uuid.UUID(self.energy_load_history_provider_id)) if self.energy_load_history_provider_id else None
+        )
+        return LoadDevice(
+            id=EntityId(uuid.uuid4()),
+            name=self.name,
+            category=LoadDeviceCategory(self.category) if isinstance(self.category, str) else self.category,
+            enabled=self.enabled,
+            energy_load_forecast_provider_id=forecast_provider_id,
+            energy_load_history_provider_id=history_provider_id,
+        )
+
+    class Config:
+        """Pydantic configuration."""
+
+        use_enum_values = True
+
+
+class LoadDeviceUpdateSchema(BaseModel):
+    """Schema for updating an existing load device."""
+
+    name: str = Field(default="", description="Load device name")
+    category: LoadDeviceCategory = Field(default=LoadDeviceCategory.OCCASIONAL, description="Category of load device")
+    enabled: bool = Field(default=True, description="Whether the load device is active in the system")
+    energy_load_forecast_provider_id: Optional[str] = Field(
+        default=None, description="ID of the energy load forecast provider associated with this load device"
+    )
+    energy_load_history_provider_id: Optional[str] = Field(
+        default=None, description="ID of the energy load history provider associated with this load device"
+    )
+
+    @field_validator("name")
+    @classmethod
+    def validate_name(cls, v: str) -> str:
+        """Validate device name."""
+        if not v.strip():
+            raise ValueError("Device name cannot be empty")
+        return v.strip()
+
+    @field_validator("energy_load_forecast_provider_id", "energy_load_history_provider_id")
+    @classmethod
+    def validate_provider_id(cls, v: Optional[str]) -> Optional[str]:
+        """Validate that provider ID is a valid UUID string if provided."""
+        if v is not None:
+            try:
+                uuid.UUID(v)
+            except ValueError as exc:
+                raise ValueError("Provider ID must be a valid UUID string") from exc
+        return v
+
+    @field_serializer("energy_load_forecast_provider_id", "energy_load_history_provider_id")
+    def serialize_provider_id(self, value: Optional[str]) -> Optional[str]:
+        """Serialize provider ID field."""
+        return value
+
+    class Config:
+        """Pydantic configuration."""
+
+        use_enum_values = True
+
+
+class HomeLoadsProfileSchema(BaseModel):
+    """Schema for HomeLoadsProfile aggregate root."""
+
+    id: str = Field(..., description="Unique identifier for the home loads profile")
+    name: str = Field(default="Default Home Profile", description="Profile name")
+    devices: List[LoadDeviceSchema] = Field(default_factory=list, description="Load devices in this profile")
+
+    @field_validator("id")
+    @classmethod
+    def validate_id(cls, v: str) -> str:
+        """Validate that id is a valid UUID string."""
+        try:
+            uuid.UUID(v)
+            return v
+        except ValueError as e:
+            raise ValueError(f"Invalid UUID format: {v}") from e
+
+    @field_validator("name")
+    @classmethod
+    def validate_name(cls, v: str) -> str:
+        """Validate profile name."""
+        if not v.strip():
+            raise ValueError("Profile name cannot be empty")
+        return v.strip()
+
+    @classmethod
+    def from_model(cls, profile: HomeLoadsProfile) -> "HomeLoadsProfileSchema":
+        """Create schema from domain model."""
+        devices = []
+        for device in profile.devices:
+            devices.append(LoadDeviceSchema.from_model(device))
+
+        return cls(
+            id=str(profile.id),
+            name=profile.name,
+            devices=devices,
+        )
+
+    @field_serializer("id")
+    def serialize_id(self, value: str) -> str:
+        """Serialize id field."""
+        return value
+
+    def to_model(self) -> HomeLoadsProfile:
+        """Convert schema to domain model."""
+        devices = []
+        for device_schema in self.devices:
+            devices.append(device_schema.to_model())
+
+        return HomeLoadsProfile(
+            id=EntityId(uuid.UUID(self.id)),
+            name=self.name,
+            devices=devices,
+        )
+
+    class Config:
+        """Pydantic configuration."""
+
+        use_enum_values = True
+
+
+class HomeLoadsProfileCreateSchema(BaseModel):
+    """Schema for creating a new home loads profile."""
+
+    name: str = Field(default="Default Home Profile", description="Profile name")
+
+    @field_validator("name")
+    @classmethod
+    def validate_name(cls, v: str) -> str:
+        """Validate profile name."""
+        if not v.strip():
+            raise ValueError("Profile name cannot be empty")
+        return v.strip()
+
+    def to_model(self) -> HomeLoadsProfile:
+        """Convert schema to domain model."""
+        return HomeLoadsProfile(
+            id=EntityId(uuid.uuid4()),
+            name=self.name,
+            devices=[],
+        )
+
+    class Config:
+        """Pydantic configuration."""
+
+        use_enum_values = True
+
+
+class HomeLoadsProfileUpdateSchema(BaseModel):
+    """Schema for updating an existing home loads profile."""
+
+    name: str = Field(default="", description="Profile name")
+
+    @field_validator("name")
+    @classmethod
+    def validate_name(cls, v: str) -> str:
+        """Validate profile name."""
+        if not v.strip():
+            raise ValueError("Profile name cannot be empty")
+        return v.strip()
+
+    class Config:
+        """Pydantic configuration."""
+
+        use_enum_values = True
+
+
+class EnergyLoadForecastProviderSchema(BaseModel):
+    """Schema for EnergyLoadForecastProvider entity with complete validation."""
+
+    id: str = Field(..., description="Unique identifier for the energy load forecast provider")
+    name: str = Field(default="", description="Energy load forecast provider name")
+    adapter_type: EnergyLoadForecastProviderAdapter = Field(
+        default=EnergyLoadForecastProviderAdapter.DUMMY,
+        description="Type of energy load forecast provider adapter",
+    )
+    config: dict = Field(default={}, description="Energy load forecast provider configuration")
+    external_service_id: Optional[str] = Field(default=None, description="ID of external service")
+
+    @computed_field  # type: ignore[prop-decorator]
+    @property
+    def min_required_history_hours(self) -> int:
+        """Minimum hours of historical data the provider needs to produce a forecast."""
+        adapter = self.adapter_type
+        cfg = self.config or {}
+
+        if adapter == EnergyLoadForecastProviderAdapter.NAIVE_LAST_HOUR:
+            return 1
+        if adapter == EnergyLoadForecastProviderAdapter.NAIVE_PERSISTENCE:
+            delta_days = int(cfg.get("delta_days", 1))
+            return delta_days * 24
+        if adapter == EnergyLoadForecastProviderAdapter.SKFORECAST:
+            num_lags = int(cfg.get("num_lags", 72))
+            hours_ahead = int(cfg.get("hours_ahead", 24))
+            return num_lags + 48 + hours_ahead
+        if adapter == EnergyLoadForecastProviderAdapter.STATSMODELS:
+            seasonal_periods = int(cfg.get("seasonal_periods", 24))
+            return seasonal_periods * 2
+        if adapter == EnergyLoadForecastProviderAdapter.TYPICAL_PROFILE:
+            weeks_lookback = int(cfg.get("weeks_lookback", 8))
+            return weeks_lookback * 168
+        if adapter == EnergyLoadForecastProviderAdapter.XGBOOST:
+            hours_ahead = int(cfg.get("hours_ahead", 3))
+            return 168 + 48 + hours_ahead
+        return 0
+
+    @field_validator("id")
+    @classmethod
+    def validate_id(cls, v: str) -> str:
+        """Validate that id is a valid UUID string."""
+        try:
+            uuid.UUID(v)
+            return v
+        except ValueError as e:
+            raise ValueError(f"Invalid UUID format: {v}") from e
+
+    @field_validator("name")
+    @classmethod
+    def validate_name(cls, v: str) -> str:
+        """Validate provider name."""
+        if not v.strip():
+            raise ValueError("Provider name cannot be empty")
+        return v.strip()
+
     @field_validator("adapter_type")
     @classmethod
-    def validate_adapter_type(cls, v: str) -> HomeForecastProviderAdapter:
-        """Validate that adapter_type is a recognized HomeForecastProviderAdapter."""
-        adapter_values = [adapter.value for adapter in HomeForecastProviderAdapter]
+    def validate_adapter_type(cls, v: str) -> EnergyLoadForecastProviderAdapter:
+        """Validate that adapter_type is a recognized EnergyLoadForecastProviderAdapter."""
+        adapter_values = [adapter.value for adapter in EnergyLoadForecastProviderAdapter]
         if v not in adapter_values:
             raise ValueError(f"adapter_type must be one of {adapter_values}")
-        return HomeForecastProviderAdapter(v)
+        return EnergyLoadForecastProviderAdapter(v)
 
     @field_validator("external_service_id")
     @classmethod
@@ -96,37 +551,39 @@ def validate_external_service_id(cls, v: Optional[str]) -> Optional[str]:
         return v
 
     @classmethod
-    def from_model(cls, home_forecast_provider: HomeForecastProvider) -> "HomeForecastProviderSchema":
-        """Create HomeForecastProviderSchema from a HomeForecastProvider domain model instance."""
+    def from_model(cls, provider: EnergyLoadForecastProvider) -> "EnergyLoadForecastProviderSchema":
+        """Create schema from domain model."""
+        config_dict = {}
+        if provider.config:
+            config_dict = provider.config.to_dict()
+
         return cls(
-            id=str(home_forecast_provider.id),
-            name=home_forecast_provider.name,
-            adapter_type=home_forecast_provider.adapter_type,
-            config=home_forecast_provider.config.to_dict() if home_forecast_provider.config else {},
-            external_service_id=(
-                str(home_forecast_provider.external_service_id) if home_forecast_provider.external_service_id else None
-            ),
+            id=str(provider.id),
+            name=provider.name,
+            adapter_type=provider.adapter_type,
+            config=config_dict,
+            external_service_id=str(provider.external_service_id) if provider.external_service_id else None,
         )
 
     @field_serializer("id")
     def serialize_id(self, value: str) -> str:
         """Serialize id field."""
-        return str(value)
+        return value
 
     @field_serializer("external_service_id")
     def serialize_external_service_id(self, value: Optional[str]) -> Optional[str]:
-        """Serialize external_service_id field."""
-        return str(value) if value is not None else None
+        """Serialize external service id field."""
+        return value
 
-    def to_model(self) -> HomeForecastProvider:
-        """Convert HomeForecastProviderSchema to HomeForecastProvider domain model instance."""
-        configuration: Optional[HomeForecastProviderConfig] = None
+    def to_model(self) -> EnergyLoadForecastProvider:
+        """Convert schema to domain model."""
+        configuration: Optional[EnergyLoadForecastProviderConfig] = None
         if self.config:
-            config_class = HOME_FORECAST_PROVIDER_CONFIG_TYPE_MAP.get(self.adapter_type, None)
-            if config_class:
-                configuration = cast(HomeForecastProviderConfig, config_class.from_dict(self.config))
+            config_type = ENERGY_LOAD_FORECAST_PROVIDER_CONFIG_TYPE_MAP.get(self.adapter_type)
+            if config_type:
+                configuration = cast(EnergyLoadForecastProviderConfig, config_type.from_dict(self.config))
 
-        return HomeForecastProvider(
+        return EnergyLoadForecastProvider(
             id=EntityId(uuid.UUID(self.id)),
             name=self.name,
             adapter_type=self.adapter_type,
@@ -142,37 +599,37 @@ class Config:
         arbitrary_types_allowed = True
         json_encoders = {
             uuid.UUID: str,
-            HomeForecastProviderAdapter: lambda v: v.value,
+            EnergyLoadForecastProviderAdapter: lambda v: v.value,
         }
 
 
-class HomeForecastProviderCreateSchema(BaseModel):
-    """Schema for creating a new home forecast provider."""
+class EnergyLoadForecastProviderCreateSchema(BaseModel):
+    """Schema for creating a new energy load forecast provider."""
 
-    name: str = Field(default="", description="Home forecast provider name")
-    adapter_type: HomeForecastProviderAdapter = Field(
-        default=HomeForecastProviderAdapter.DUMMY, description="Type of home forecast provider adapter"
+    name: str = Field(default="", description="Energy load forecast provider name")
+    adapter_type: EnergyLoadForecastProviderAdapter = Field(
+        default=EnergyLoadForecastProviderAdapter.DUMMY,
+        description="Type of energy load forecast provider adapter",
    )
-    config: Optional[dict] = Field(default=None, description="Home forecast provider configuration")
+    config: Optional[dict] = Field(default=None, description="Energy load forecast provider configuration")
     external_service_id: Optional[str] = Field(default=None, description="ID of external service")
 
     @field_validator("name")
     @classmethod
     def validate_name(cls, v: str) -> str:
-        """Validate home forecast provider name."""
-        v = v.strip()
-        if not v:
-            v = ""
-        return v
+        """Validate provider name."""
+        if not v.strip():
+            raise ValueError("Provider name cannot be empty")
+        return v.strip()
 
     @field_validator("adapter_type")
     @classmethod
-    def validate_adapter_type(cls, v: str) -> HomeForecastProviderAdapter:
-        """Validate that adapter_type is a recognized HomeForecastProviderAdapter."""
-        adapter_values = [adapter.value for adapter in HomeForecastProviderAdapter]
+    def validate_adapter_type(cls, v: str) -> EnergyLoadForecastProviderAdapter:
+        """Validate that adapter_type is a recognized EnergyLoadForecastProviderAdapter."""
+        adapter_values = [adapter.value for adapter in EnergyLoadForecastProviderAdapter]
         if v not in adapter_values:
             raise ValueError(f"adapter_type must be one of {adapter_values}")
-        return HomeForecastProviderAdapter(v)
+        return EnergyLoadForecastProviderAdapter(v)
 
     @field_validator("external_service_id")
     @classmethod
@@ -185,15 +642,15 @@ def validate_external_service_id(cls, v: Optional[str]) -> Optional[str]:
             raise ValueError("external_service_id must be a valid UUID string") from exc
         return v
 
-    def to_model(self) -> HomeForecastProvider:
-        """Convert HomeForecastProviderCreateSchema to a HomeForecastProvider domain model instance."""
-        configuration: Optional[HomeForecastProviderConfig] = None
+    def to_model(self) -> EnergyLoadForecastProvider:
+        """Convert schema to domain model."""
+        configuration: Optional[EnergyLoadForecastProviderConfig] = None
         if self.config:
-            config_class = HOME_FORECAST_PROVIDER_CONFIG_TYPE_MAP.get(self.adapter_type, None)
-            if config_class:
-                configuration = cast(HomeForecastProviderConfig, config_class.from_dict(self.config))
+            config_type = ENERGY_LOAD_FORECAST_PROVIDER_CONFIG_TYPE_MAP.get(self.adapter_type)
+            if config_type:
+                configuration = cast(EnergyLoadForecastProviderConfig, config_type.from_dict(self.config))
 
-        return HomeForecastProvider(
+        return EnergyLoadForecastProvider(
             id=EntityId(uuid.uuid4()),
             name=self.name,
             adapter_type=self.adapter_type,
@@ -208,25 +665,24 @@ class Config:
         validate_assignment = True
         json_encoders = {
             uuid.UUID: str,
-            HomeForecastProviderAdapter: lambda v: v.value,
+            EnergyLoadForecastProviderAdapter: lambda v: v.value,
         }
 
 
-class HomeForecastProviderUpdateSchema(BaseModel):
-    """Schema for updating an existing home forecast provider."""
+class EnergyLoadForecastProviderUpdateSchema(BaseModel):
+    """Schema for updating an existing energy load forecast provider."""
 
-    name: str = Field(default="", description="Home forecast provider name")
-    config: Optional[dict] = Field(default=None, description="Home forecast provider configuration")
+    name: str = Field(default="", description="Energy load forecast provider name")
+    config: Optional[dict] = Field(default=None, description="Energy load forecast provider configuration")
     external_service_id: Optional[str] = Field(default=None, description="ID of external service")
 
     @field_validator("name")
     @classmethod
     def validate_name(cls, v: str) -> str:
-        """Validate home forecast provider name."""
-        v = v.strip()
-        if not v:
-            v = ""
-        return v
+        """Validate provider name."""
+        if not v.strip():
+            raise ValueError("Provider name cannot be empty")
+        return v.strip()
 
     @field_validator("external_service_id")
     @classmethod
@@ -244,30 +700,395 @@ class Config:
 
         use_enum_values = True
         validate_assignment = True
-        json_encoders = {
-            uuid.UUID: str,
-        }
 
 
-class HomeForecastProviderDummyConfigSchema(BaseModel):
-    """Schema for Dummy Home Forecast Provider Config."""
+class EnergyLoadForecastProviderDummyConfigSchema(BaseModel):
+    """Schema for Dummy EnergyLoadForecastProviderConfig."""
 
     load_power_max: float = Field(default=500.0, ge=0, description="Maximum load power in Watts")
 
     @field_validator("load_power_max")
     @classmethod
     def validate_load_power_max(cls, v: float) -> float:
-        """Validate load_power_max is non-negative."""
+        """Validate load power max is non-negative."""
         if v < 0:
-            raise ValueError("load_power_max must be non-negative")
+            raise ValueError("Maximum load power cannot be negative")
+        return v
+
+    def to_model(self) -> EnergyLoadForecastProviderDummyConfig:
+        """Convert schema to domain model."""
+        return EnergyLoadForecastProviderDummyConfig(load_power_max=self.load_power_max)
+
+    class Config:
+        """Pydantic configuration."""
+
+        use_enum_values = True
+        validate_assignment = True
+
+
+class EnergyLoadForecastProviderNaiveLastHourConfigSchema(BaseModel):
+    """Schema for NaiveLastHour EnergyLoadForecastProviderConfig."""
+
+    hours_ahead: int = Field(default=3, ge=1, le=72, description="Number of hours to forecast ahead")
+
+    def to_model(self) -> EnergyLoadForecastProviderNaiveLastHourConfig:
+        """Convert schema to domain model."""
+        return EnergyLoadForecastProviderNaiveLastHourConfig(hours_ahead=self.hours_ahead)
+
+    class Config:
+        use_enum_values = True
+        validate_assignment = True
+
+
+class EnergyLoadForecastProviderNaivePersistenceConfigSchema(BaseModel):
+    """Schema for NaivePersistence EnergyLoadForecastProviderConfig."""
+
+    hours_ahead: int = Field(default=24, ge=1, le=72, description="Number of hours to forecast ahead")
+    delta_days: int = Field(default=1, ge=1, le=7, description="Number of days back to use as reference")
+
+    def to_model(self) -> EnergyLoadForecastProviderNaivePersistenceConfig:
+        """Convert schema to domain model."""
+        return EnergyLoadForecastProviderNaivePersistenceConfig(
+            hours_ahead=self.hours_ahead,
+            delta_days=self.delta_days,
+        )
+
+    class Config:
+        use_enum_values = True
+        validate_assignment = True
+
+
+class EnergyLoadForecastProviderSeasonalBaselineConfigSchema(BaseModel):
+    """Schema for SeasonalBaseline EnergyLoadForecastProviderConfig."""
+
+    hours_ahead: int = Field(default=3, ge=1, le=72, description="Number of hours to forecast ahead")
+    weeks_lookback: int = Field(default=4, ge=1, le=52, description="Number of weeks of history to use for profiling")
+
+    def to_model(self) -> EnergyLoadForecastProviderSeasonalBaselineConfig:
+        """Convert schema to domain model."""
+        return EnergyLoadForecastProviderSeasonalBaselineConfig(
+            hours_ahead=self.hours_ahead,
+            weeks_lookback=self.weeks_lookback,
+        )
+
+    class Config:
+        use_enum_values = True
+        validate_assignment = True
+
+
+class EnergyLoadForecastProviderTypicalProfileConfigSchema(BaseModel):
+    """Schema for TypicalProfile EnergyLoadForecastProviderConfig."""
+
+    hours_ahead: int = Field(default=24, ge=1, le=72, description="Number of hours to forecast ahead")
+    weeks_lookback: int = Field(
+        default=8, ge=1, le=52, description="Weeks of history to build the typical profile from"
+    )
+
+    def to_model(self) -> EnergyLoadForecastProviderTypicalProfileConfig:
+        """Convert schema to domain model."""
+        return EnergyLoadForecastProviderTypicalProfileConfig(
+            hours_ahead=self.hours_ahead,
+            weeks_lookback=self.weeks_lookback,
+        )
+
+    class Config:
+        use_enum_values = True
+        validate_assignment = True
+
+
+class EnergyLoadForecastProviderSkforecastConfigSchema(BaseModel):
+    """Schema for Skforecast EnergyLoadForecastProviderConfig."""
+
+    hours_ahead: int = Field(default=24, ge=1, le=72, description="Number of hours to forecast ahead")
+    weeks_lookback: int = Field(default=8, ge=1, le=52, description="Weeks of history for training")
+    sklearn_model: str = Field(
+        default="RandomForestRegressor",
+        description="Name of the sklearn regressor class to use as backend",
+    )
+    num_lags: int = Field(default=72, ge=6, le=336, description="Number of lag features (hours)")
+
+    def to_model(self) -> EnergyLoadForecastProviderSkforecastConfig:
+        """Convert schema to domain model."""
+        return EnergyLoadForecastProviderSkforecastConfig(
+            hours_ahead=self.hours_ahead,
+            weeks_lookback=self.weeks_lookback,
+            sklearn_model=self.sklearn_model,
+            num_lags=self.num_lags,
+        )
+
+    class Config:
+        use_enum_values = True
+        validate_assignment = True
+
+
+class EnergyLoadForecastProviderStatsmodelsConfigSchema(BaseModel):
+    """Schema for Statsmodels EnergyLoadForecastProviderConfig."""
+
+    hours_ahead: int = Field(default=3, ge=1, le=72, description="Number of hours to forecast ahead")
+    weeks_lookback: int = Field(default=8, ge=1, le=52, description="Weeks of history for training")
+    method: str = Field(default="hw", description="Statsmodels method: 'hw' (Holt-Winters) or 'sarima'")
+    seasonal_periods: int = Field(default=24, ge=1, le=168, description="Hours in a seasonal cycle")
+
+    def to_model(self) -> EnergyLoadForecastProviderStatsmodelsConfig:
+        """Convert schema to domain model."""
+        return EnergyLoadForecastProviderStatsmodelsConfig(
+            hours_ahead=self.hours_ahead,
+            weeks_lookback=self.weeks_lookback,
+            method=self.method,
+            seasonal_periods=self.seasonal_periods,
+        )
+
+    class Config:
+        use_enum_values = True
+        validate_assignment = True
+
+
+class EnergyLoadForecastProviderXGBoostConfigSchema(BaseModel):
+    """Schema for XGBoost EnergyLoadForecastProviderConfig."""
+
+    hours_ahead: int = Field(default=3, ge=1, le=72, description="Number of hours to forecast ahead")
+    weeks_lookback: int = Field(default=8, ge=1, le=52, description="Weeks of history for training")
+    n_estimators: int = Field(default=100, ge=10, le=1000, description="Number of boosting rounds")
+    max_depth: int = Field(default=6, ge=1, le=15, description="Maximum tree depth")
+    learning_rate: float = Field(default=0.1, gt=0.0, le=1.0, description="Learning rate")
+
+    def to_model(self) -> EnergyLoadForecastProviderXGBoostConfig:
+        """Convert schema to domain model."""
+        return EnergyLoadForecastProviderXGBoostConfig(
+            hours_ahead=self.hours_ahead,
+            weeks_lookback=self.weeks_lookback,
+            n_estimators=self.n_estimators,
+            max_depth=self.max_depth,
+            learning_rate=self.learning_rate,
+        )
+
+    class Config:
+        use_enum_values = True
+        validate_assignment = True
+
+
+ENERGY_LOAD_FORECAST_PROVIDER_CONFIG_SCHEMA_MAP: Dict[
+    type[EnergyLoadForecastProviderConfig],
+    Union[
+        type[EnergyLoadForecastProviderDummyConfigSchema],
+        type[EnergyLoadForecastProviderNaiveLastHourConfigSchema],
+        type[EnergyLoadForecastProviderNaivePersistenceConfigSchema],
+        type[EnergyLoadForecastProviderSeasonalBaselineConfigSchema],
+        type[EnergyLoadForecastProviderSkforecastConfigSchema],
+        type[EnergyLoadForecastProviderStatsmodelsConfigSchema],
+        type[EnergyLoadForecastProviderTypicalProfileConfigSchema],
+        type[EnergyLoadForecastProviderXGBoostConfigSchema],
+    ],
+] = {
+    EnergyLoadForecastProviderDummyConfig: EnergyLoadForecastProviderDummyConfigSchema,
+    EnergyLoadForecastProviderNaiveLastHourConfig: EnergyLoadForecastProviderNaiveLastHourConfigSchema,
+    EnergyLoadForecastProviderNaivePersistenceConfig: EnergyLoadForecastProviderNaivePersistenceConfigSchema,
+    EnergyLoadForecastProviderSeasonalBaselineConfig: EnergyLoadForecastProviderSeasonalBaselineConfigSchema,
+    EnergyLoadForecastProviderSkforecastConfig: EnergyLoadForecastProviderSkforecastConfigSchema,
+    EnergyLoadForecastProviderStatsmodelsConfig: EnergyLoadForecastProviderStatsmodelsConfigSchema,
+    EnergyLoadForecastProviderTypicalProfileConfig: EnergyLoadForecastProviderTypicalProfileConfigSchema,
+    EnergyLoadForecastProviderXGBoostConfig: EnergyLoadForecastProviderXGBoostConfigSchema,
+}
+
+
+# --- Energy Load History Provider Schemas ---
+
+
+class EnergyLoadHistoryProviderSchema(BaseModel):
+    """Schema for EnergyLoadHistoryProvider entity."""
+
+    id: 
str = Field(..., description="Unique identifier for the energy load history provider") + name: str = Field(default="", description="Energy load history provider name") + adapter_type: EnergyLoadHistoryProviderAdapter = Field( + default=EnergyLoadHistoryProviderAdapter.DUMMY, + description="Type of energy load history provider adapter", + ) + config: Optional[dict] = Field(default=None, description="Energy load history provider configuration") + external_service_id: Optional[str] = Field(default=None, description="ID of external service") + + @field_validator("id") + @classmethod + def validate_id(cls, v: str) -> str: + """Validate that id is a valid UUID string.""" + try: + uuid.UUID(v) + return v + except ValueError as e: + raise ValueError(f"Invalid UUID format: {v}") from e + + @field_validator("name") + @classmethod + def validate_name(cls, v: str) -> str: + """Validate provider name.""" + if not v.strip(): + raise ValueError("Provider name cannot be empty") + return v.strip() + + @field_validator("adapter_type") + @classmethod + def validate_adapter_type(cls, v: str) -> EnergyLoadHistoryProviderAdapter: + """Validate that adapter_type is a recognized EnergyLoadHistoryProviderAdapter.""" + adapter_values = [adapter.value for adapter in EnergyLoadHistoryProviderAdapter] + if v not in adapter_values: + raise ValueError(f"adapter_type must be one of {adapter_values}") + return EnergyLoadHistoryProviderAdapter(v) + + @field_validator("external_service_id") + @classmethod + def validate_external_service_id(cls, v: Optional[str]) -> Optional[str]: + """Validate that external_service_id is a valid UUID string if provided.""" + if v is not None: + try: + uuid.UUID(v) + except ValueError as exc: + raise ValueError("external_service_id must be a valid UUID string") from exc return v - def to_model(self) -> HomeForecastProviderDummyConfig: - """Convert schema to HomeForecastProviderDummyConfig adapter configuration model instance.""" - return 
HomeForecastProviderDummyConfig( - load_power_max=self.load_power_max, + @classmethod + def from_model(cls, provider: EnergyLoadHistoryProvider) -> "EnergyLoadHistoryProviderSchema": + """Create schema from domain model.""" + config_dict = None + if provider.config: + config_dict = provider.config.to_dict() + + return cls( + id=str(provider.id), + name=provider.name, + adapter_type=provider.adapter_type, + config=config_dict, + external_service_id=str(provider.external_service_id) if provider.external_service_id else None, ) + @field_serializer("id") + def serialize_id(self, value: str) -> str: + """Serialize id field.""" + return value + + @field_serializer("external_service_id") + def serialize_external_service_id(self, value: Optional[str]) -> Optional[str]: + """Serialize external service id field.""" + return value + + def to_model(self) -> EnergyLoadHistoryProvider: + """Convert schema to domain model.""" + configuration: Optional[EnergyLoadHistoryProviderConfig] = None + if self.config: + config_type = ENERGY_LOAD_HISTORY_PROVIDER_CONFIG_TYPE_MAP.get(self.adapter_type) + if config_type: + configuration = cast(EnergyLoadHistoryProviderConfig, config_type.from_dict(self.config)) + + return EnergyLoadHistoryProvider( + id=EntityId(uuid.UUID(self.id)), + name=self.name, + adapter_type=self.adapter_type, + config=configuration, + external_service_id=EntityId(uuid.UUID(self.external_service_id)) if self.external_service_id else None, + ) + + class Config: + """Pydantic configuration.""" + + use_enum_values = True + validate_assignment = True + arbitrary_types_allowed = True + json_encoders = { + uuid.UUID: str, + EnergyLoadHistoryProviderAdapter: lambda v: v.value, + } + + +class EnergyLoadHistoryProviderCreateSchema(BaseModel): + """Schema for creating a new energy load history provider.""" + + name: str = Field(default="", description="Energy load history provider name") + adapter_type: EnergyLoadHistoryProviderAdapter = Field( + 
default=EnergyLoadHistoryProviderAdapter.DUMMY, + description="Type of energy load history provider adapter", + ) + config: Optional[dict] = Field(default=None, description="Energy load history provider configuration") + external_service_id: Optional[str] = Field(default=None, description="ID of external service") + + @field_validator("name") + @classmethod + def validate_name(cls, v: str) -> str: + """Validate provider name.""" + if not v.strip(): + raise ValueError("Provider name cannot be empty") + return v.strip() + + @field_validator("adapter_type") + @classmethod + def validate_adapter_type(cls, v: str) -> EnergyLoadHistoryProviderAdapter: + """Validate that adapter_type is a recognized EnergyLoadHistoryProviderAdapter.""" + adapter_values = [adapter.value for adapter in EnergyLoadHistoryProviderAdapter] + if v not in adapter_values: + raise ValueError(f"adapter_type must be one of {adapter_values}") + return EnergyLoadHistoryProviderAdapter(v) + + @field_validator("external_service_id") + @classmethod + def validate_external_service_id(cls, v: Optional[str]) -> Optional[str]: + """Validate that external_service_id is a valid UUID string if provided.""" + if v is not None: + try: + uuid.UUID(v) + except ValueError as exc: + raise ValueError("external_service_id must be a valid UUID string") from exc + return v + + def to_model(self) -> EnergyLoadHistoryProvider: + """Convert schema to domain model.""" + configuration: Optional[EnergyLoadHistoryProviderConfig] = None + if self.config: + config_type = ENERGY_LOAD_HISTORY_PROVIDER_CONFIG_TYPE_MAP.get(self.adapter_type) + if config_type: + configuration = cast(EnergyLoadHistoryProviderConfig, config_type.from_dict(self.config)) + + return EnergyLoadHistoryProvider( + id=EntityId(uuid.uuid4()), + name=self.name, + adapter_type=self.adapter_type, + config=configuration, + external_service_id=EntityId(uuid.UUID(self.external_service_id)) if self.external_service_id else None, + ) + + class Config: + """Pydantic 
configuration.""" + + use_enum_values = True + validate_assignment = True + json_encoders = { + uuid.UUID: str, + EnergyLoadHistoryProviderAdapter: lambda v: v.value, + } + + +class EnergyLoadHistoryProviderUpdateSchema(BaseModel): + """Schema for updating an existing energy load history provider.""" + + name: str = Field(default="", description="Energy load history provider name") + config: Optional[dict] = Field(default=None, description="Energy load history provider configuration") + external_service_id: Optional[str] = Field(default=None, description="ID of external service") + + @field_validator("name") + @classmethod + def validate_name(cls, v: str) -> str: + """Validate provider name.""" + if not v.strip(): + raise ValueError("Provider name cannot be empty") + return v.strip() + + @field_validator("external_service_id") + @classmethod + def validate_external_service_id(cls, v: Optional[str]) -> Optional[str]: + """Validate that external_service_id is a valid UUID string if provided.""" + if v is not None: + try: + uuid.UUID(v) + except ValueError as exc: + raise ValueError("external_service_id must be a valid UUID string") from exc + return v + class Config: """Pydantic configuration.""" @@ -275,9 +1096,89 @@ class Config: validate_assignment = True -HOME_FORECAST_PROVIDER_CONFIG_SCHEMA_MAP: Dict[ - type[HomeForecastProviderConfig], - Union[type[HomeForecastProviderDummyConfigSchema]], +class EnergyLoadHistoryProviderHomeAssistantAPIConfigSchema(BaseModel): + """Schema for HomeAssistantAPI EnergyLoadHistoryProviderConfig.""" + + entity_power: str = Field(default="", description="Home Assistant entity ID for power sensor") + unit_power: str = Field(default="W", description="Unit of power measurement") + + @field_validator("entity_power") + @classmethod + def validate_entity_power(cls, v: str) -> str: + """Validate entity_power is not empty.""" + if not v.strip(): + raise ValueError("entity_power cannot be empty") + return v.strip() + + def to_model(self) -> 
EnergyLoadHistoryProviderHomeAssistantAPIConfig: + """Convert schema to domain model.""" + return EnergyLoadHistoryProviderHomeAssistantAPIConfig( + entity_power=self.entity_power, unit_power=self.unit_power + ) + + class Config: + """Pydantic configuration.""" + + use_enum_values = True + validate_assignment = True + + +ENERGY_LOAD_HISTORY_PROVIDER_CONFIG_SCHEMA_MAP: Dict[ + type[EnergyLoadHistoryProviderConfig], + Union[type[EnergyLoadHistoryProviderHomeAssistantAPIConfigSchema]], ] = { - HomeForecastProviderDummyConfig: HomeForecastProviderDummyConfigSchema, + EnergyLoadHistoryProviderHomeAssistantAPIConfig: EnergyLoadHistoryProviderHomeAssistantAPIConfigSchema, } + + +class HomeLoadPowerPointSchema(BaseModel): + """Schema for HomeLoadPowerPoint value object.""" + + timestamp: datetime = Field(..., description="Measurement timestamp") + power: float = Field(..., description="Power in watts") + + @classmethod + def from_model(cls, point: HomeLoadPowerPoint) -> "HomeLoadPowerPointSchema": + return cls( + timestamp=cast(datetime, point.timestamp), + power=float(point.power), + ) + + +class LoadConsumptionModelSchema(BaseModel): + """Schema for LoadConsumptionModel entity (without serialized model bytes).""" + + id: str = Field(..., description="Model unique identifier") + device_id: Optional[str] = Field(default=None, description="Device this model was trained for") + adapter_type: EnergyLoadForecastProviderAdapter = Field(..., description="ML adapter type") + trained_at: Optional[datetime] = Field(default=None, description="Training timestamp") + mae: Optional[float] = Field(default=None, description="Mean absolute error on holdout") + rmse: Optional[float] = Field(default=None, description="Root mean squared error on holdout") + samples_used: int = Field(default=0, description="Number of training samples") + is_active: bool = Field(default=False, description="Whether the model is currently active") + tuning_params: Optional[dict] = Field(default=None, 
description="Best hyperparameters from Optuna tuning")
+    backtest_mae: Optional[float] = Field(default=None, description="MAE from rolling-window backtesting")
+    backtest_rmse: Optional[float] = Field(default=None, description="RMSE from rolling-window backtesting")
+    backtest_folds: int = Field(default=0, description="Number of folds used in backtesting")
+
+    @classmethod
+    def from_model(cls, model: LoadConsumptionModel) -> "LoadConsumptionModelSchema":
+        return cls(
+            id=str(model.id),
+            device_id=str(model.device_id) if model.device_id else None,
+            adapter_type=model.adapter_type,
+            trained_at=model.trained_at,
+            mae=model.mae,
+            rmse=model.rmse,
+            samples_used=model.samples_used,
+            is_active=model.is_active,
+            tuning_params=model.tuning_params,
+            backtest_mae=model.backtest_mae,
+            backtest_rmse=model.backtest_rmse,
+            backtest_folds=model.backtest_folds,
+        )
+
+    class Config:
+        """Pydantic configuration."""
+
+        use_enum_values = True
diff --git a/edge_mining/adapters/domain/home_load/tables.py b/edge_mining/adapters/domain/home_load/tables.py
index 999f4b7..0f67647 100644
--- a/edge_mining/adapters/domain/home_load/tables.py
+++ b/edge_mining/adapters/domain/home_load/tables.py
@@ -6,14 +6,13 @@
 The mappings handle complex objects using SQLAlchemy event listeners
 and custom types:
 - LoadDevice dictionaries are serialized to JSON and reconstructed after loading
-- HomeForecastProviderConfig is serialized using custom ConfigurationType
+- EnergyLoadForecastProviderConfig is serialized using custom ConfigurationType
 - EntityId value objects are implicitly converted to/from strings
 
 All tables and mappings use the shared metadata and mapper registry from the
 sqlalchemy.registry module, which are available as module-level singletons.
 
-⚠️ DEVELOPER WARNING ⚠️
-━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
+DEVELOPER WARNING
 ANY SCHEMA CHANGE (adding/removing/modifying tables or columns) REQUIRES
 an Alembic migration. 
Do NOT modify this file without creating a migration: @@ -21,63 +20,94 @@ For detailed instructions, see: docs/ALEMBIC_MIGRATIONS.md For a step-by-step example, see: docs/MIGRATION_EXAMPLE.md -━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ """ import json import uuid from typing import Any, Optional -from sqlalchemy import JSON, Column, ForeignKey, String, Table, TypeDecorator, event +from sqlalchemy import ( + JSON, + Boolean, + Column, + DateTime, + Float, + ForeignKey, + Index, + Integer, + LargeBinary, + String, + Table, + Text, + TypeDecorator, + event, +) from edge_mining.adapters.infrastructure.persistence.sqlalchemy.common import ConfigurationType from edge_mining.adapters.infrastructure.persistence.sqlalchemy.registry import mapper_registry, metadata from edge_mining.domain.common import EntityId from edge_mining.domain.home_load.aggregate_roots import HomeLoadsProfile -from edge_mining.domain.home_load.common import HomeForecastProviderAdapter -from edge_mining.domain.home_load.entities import HomeForecastProvider, LoadDevice -from edge_mining.domain.home_load.exceptions import HomeForecastProviderConfigurationError -from edge_mining.shared.adapter_maps.home_load import HOME_FORECAST_PROVIDER_CONFIG_TYPE_MAP -from edge_mining.shared.interfaces.config import HomeForecastProviderConfig +from edge_mining.domain.home_load.common import ( + EnergyLoadForecastProviderAdapter, + EnergyLoadHistoryProviderAdapter, + LoadDeviceCategory, +) +from edge_mining.domain.home_load.entities import ( + EnergyLoadForecastProvider, + EnergyLoadHistoryProvider, + LoadConsumptionModel, + LoadDevice, +) +from edge_mining.domain.home_load.exceptions import ( + EnergyLoadForecastProviderConfigurationError, + EnergyLoadHistoryProviderConfigurationError, +) +from edge_mining.shared.adapter_maps.home_load import ( + ENERGY_LOAD_FORECAST_PROVIDER_CONFIG_TYPE_MAP, + ENERGY_LOAD_HISTORY_PROVIDER_CONFIG_TYPE_MAP, +) +from 
edge_mining.shared.interfaces.config import EnergyLoadForecastProviderConfig, EnergyLoadHistoryProviderConfig -class HomeForecastProviderConfigType(ConfigurationType): - """SQLAlchemy type for HomeForecastProviderConfig serialization. +class EnergyLoadForecastProviderConfigType(ConfigurationType): + """SQLAlchemy type for EnergyLoadForecastProviderConfig serialization. Inherits from ConfigurationType to handle JSON serialization/deserialization. """ -def _deserialize_home_forecast_provider_config( - adapter_type: HomeForecastProviderAdapter, config_json: str -) -> Optional[HomeForecastProviderConfig]: - """Deserialize JSON string to HomeForecastProviderConfig based on adapter type.""" +def _deserialize_energy_load_forecast_provider_config( + adapter_type: EnergyLoadForecastProviderAdapter, config_json: str +) -> Optional[EnergyLoadForecastProviderConfig]: + """Deserialize JSON string to EnergyLoadForecastProviderConfig based on adapter type.""" if not config_json: return None data: dict = json.loads(config_json) - if adapter_type not in HOME_FORECAST_PROVIDER_CONFIG_TYPE_MAP: - raise HomeForecastProviderConfigurationError( - f"Error reading HomeForecastProvider configuration. Invalid type '{adapter_type}'" + if adapter_type not in ENERGY_LOAD_FORECAST_PROVIDER_CONFIG_TYPE_MAP: + raise EnergyLoadForecastProviderConfigurationError( + f"Error reading EnergyLoadForecastProvider configuration. Invalid type '{adapter_type}'" ) - config_class: Optional[type[HomeForecastProviderConfig]] = HOME_FORECAST_PROVIDER_CONFIG_TYPE_MAP.get(adapter_type) + config_class: Optional[type[EnergyLoadForecastProviderConfig]] = ENERGY_LOAD_FORECAST_PROVIDER_CONFIG_TYPE_MAP.get( + adapter_type + ) if not config_class: - raise HomeForecastProviderConfigurationError( - f"Error creating HomeForecastProvider configuration. Type '{adapter_type}'" + raise EnergyLoadForecastProviderConfigurationError( + f"Error creating EnergyLoadForecastProvider configuration. 
Type '{adapter_type}'" ) config_instance = config_class.from_dict(data) - if not isinstance(config_instance, HomeForecastProviderConfig): - raise HomeForecastProviderConfigurationError( - f"Deserialized config is not of type HomeForecastProviderConfig for adapter type {adapter_type}." + if not isinstance(config_instance, EnergyLoadForecastProviderConfig): + raise EnergyLoadForecastProviderConfigurationError( + f"Deserialized config is not of type EnergyLoadForecastProviderConfig for adapter type {adapter_type}." ) return config_instance -@event.listens_for(HomeForecastProvider, "load") -def _receive_home_forecast_provider_load(target: HomeForecastProvider, context) -> None: +@event.listens_for(EnergyLoadForecastProvider, "load") +def _receive_energy_load_forecast_provider_load(target: EnergyLoadForecastProvider, context) -> None: """Event listener that deserializes config after loading from database.""" # Convert id string to EntityId if needed if hasattr(target, "id") and target.id is not None: @@ -92,76 +122,82 @@ def _receive_home_forecast_provider_load(target: HomeForecastProvider, context) # Convert adapter_type string to enum if needed if isinstance(target.adapter_type, str): try: - target.adapter_type = HomeForecastProviderAdapter(target.adapter_type) + target.adapter_type = EnergyLoadForecastProviderAdapter(target.adapter_type) except ValueError: pass if target.config and isinstance(target.config, str): - target.config = _deserialize_home_forecast_provider_config(target.adapter_type, target.config) + target.config = _deserialize_energy_load_forecast_provider_config(target.adapter_type, target.config) -@event.listens_for(HomeForecastProvider, "before_insert") -@event.listens_for(HomeForecastProvider, "before_update") -def _flatten_home_forecast_provider_composites(mapper, connection, target: Any) -> None: +@event.listens_for(EnergyLoadForecastProvider, "before_insert") +@event.listens_for(EnergyLoadForecastProvider, "before_update") +def 
_flatten_energy_load_forecast_provider_composites(mapper, connection, target: Any) -> None: """Convert enum attributes to primitive values before persisting.""" if hasattr(target, "adapter_type") and target.adapter_type is not None: - if isinstance(target.adapter_type, HomeForecastProviderAdapter): + if isinstance(target.adapter_type, EnergyLoadForecastProviderAdapter): target.adapter_type = target.adapter_type.value -@event.listens_for(HomeForecastProvider, "after_insert") -@event.listens_for(HomeForecastProvider, "after_update") -def _restore_home_forecast_provider_composites(mapper, connection, target: Any) -> None: +@event.listens_for(EnergyLoadForecastProvider, "after_insert") +@event.listens_for(EnergyLoadForecastProvider, "after_update") +def _restore_energy_load_forecast_provider_composites(mapper, connection, target: Any) -> None: """Restore enum attributes after persist operations.""" if hasattr(target, "adapter_type") and target.adapter_type is not None: if isinstance(target.adapter_type, str): try: - target.adapter_type = HomeForecastProviderAdapter(target.adapter_type) + target.adapter_type = EnergyLoadForecastProviderAdapter(target.adapter_type) except ValueError: pass -# Define the home_forecast_providers table using imperative style -home_forecast_providers_table = Table( - "home_forecast_providers", +# Define the energy_load_forecast_providers table using imperative style +energy_load_forecast_providers_table = Table( + "energy_load_forecast_providers", metadata, Column("id", String, primary_key=True, index=True), Column("name", String, nullable=False), Column("adapter_type", String, nullable=False), - Column("config", HomeForecastProviderConfigType, nullable=True), + Column("config", EnergyLoadForecastProviderConfigType, nullable=True), Column("external_service_id", String, ForeignKey("external_services.id"), nullable=True), ) -# Map HomeForecastProvider +# Map EnergyLoadForecastProvider mapper_registry.map_imperatively( - HomeForecastProvider, - 
home_forecast_providers_table, + EnergyLoadForecastProvider, + energy_load_forecast_providers_table, ) -# Custom TypeDecorator for LoadDevice dictionary serialization +# Custom TypeDecorator for LoadDevice list serialization class LoadDevicesDictType(TypeDecorator): - """Custom type for serializing Dict[EntityId, LoadDevice] to JSON.""" + """Custom type for serializing List[LoadDevice] to a JSON array.""" impl = JSON cache_ok = True def process_bind_param(self, value, dialect): - """Convert Dict[EntityId, LoadDevice] to JSON dict for database storage.""" + """Convert List[LoadDevice] to a JSON list for database storage.""" if value is None: return None - # Convert to dict with string keys and LoadDevice dicts - return { - str(device_id): { + return [ + { "id": str(device.id), "name": device.name, - "type": device.type, + "category": device.category.value, + "enabled": device.enabled, + "energy_load_forecast_provider_id": ( + str(device.energy_load_forecast_provider_id) if device.energy_load_forecast_provider_id else None + ), + "energy_load_history_provider_id": ( + str(device.energy_load_history_provider_id) if device.energy_load_history_provider_id else None + ), } - for device_id, device in value.items() - } + for device in value + ] def process_result_value(self, value, dialect): - """Return raw JSON dict - will be reconstructed in event listener.""" + """Return raw JSON list - will be reconstructed in event listener.""" return value @@ -179,17 +215,29 @@ def process_result_value(self, value, dialect): @event.listens_for(HomeLoadsProfile, "load") def _receive_home_profile_load(target, context): """Reconstruct LoadDevice objects from JSON after loading from database.""" - if target.devices and isinstance(target.devices, dict): - reconstructed_devices = {} - for device_id_str, device_data in target.devices.items(): - if isinstance(device_data, dict): - device = LoadDevice( - id=EntityId(device_data["id"]), + if isinstance(target.id, str): + target.id = 
EntityId(uuid.UUID(target.id)) + + if target.devices and isinstance(target.devices, list): + reconstructed: list = [] + for device_data in target.devices: + if not isinstance(device_data, dict): + continue + forecast_id = device_data.get("energy_load_forecast_provider_id") + history_id = device_data.get("energy_load_history_provider_id") + reconstructed.append( + LoadDevice( + id=EntityId(uuid.UUID(device_data["id"])), name=device_data["name"], - type=device_data["type"], + category=LoadDeviceCategory(device_data["category"]), + enabled=bool(device_data.get("enabled", True)), + energy_load_forecast_provider_id=(EntityId(uuid.UUID(forecast_id)) if forecast_id else None), + energy_load_history_provider_id=(EntityId(uuid.UUID(history_id)) if history_id else None), ) - reconstructed_devices[EntityId(device_id_str)] = device - target.devices = reconstructed_devices + ) + target.devices = reconstructed + elif target.devices is None: + target.devices = [] # Map HomeLoadsProfile aggregate root to table @@ -202,3 +250,205 @@ def _receive_home_profile_load(target, context): "devices": home_profiles_table.c.devices_json, }, ) + + +# --- EnergyLoadHistoryProvider table + mapping --- + + +class EnergyLoadHistoryProviderConfigType(ConfigurationType): + """SQLAlchemy type for EnergyLoadHistoryProviderConfig serialization.""" + + +def _deserialize_energy_load_history_provider_config( + adapter_type: EnergyLoadHistoryProviderAdapter, config_json: str +) -> Optional[EnergyLoadHistoryProviderConfig]: + """Deserialize JSON string to EnergyLoadHistoryProviderConfig based on adapter type.""" + if not config_json: + return None + + data: dict = json.loads(config_json) + + if adapter_type not in ENERGY_LOAD_HISTORY_PROVIDER_CONFIG_TYPE_MAP: + raise EnergyLoadHistoryProviderConfigurationError( + f"Error reading EnergyLoadHistoryProvider configuration. 
Invalid type '{adapter_type}'" + ) + + config_class: Optional[type[EnergyLoadHistoryProviderConfig]] = ENERGY_LOAD_HISTORY_PROVIDER_CONFIG_TYPE_MAP.get( + adapter_type + ) + if not config_class: + # Some adapters (e.g. DUMMY) have no config + return None + + config_instance = config_class.from_dict(data) + if not isinstance(config_instance, EnergyLoadHistoryProviderConfig): + raise EnergyLoadHistoryProviderConfigurationError( + f"Deserialized config is not of type EnergyLoadHistoryProviderConfig for adapter type {adapter_type}." + ) + return config_instance + + +@event.listens_for(EnergyLoadHistoryProvider, "load") +def _receive_energy_load_history_provider_load(target: EnergyLoadHistoryProvider, context) -> None: + """Event listener that deserializes config after loading from database.""" + if hasattr(target, "id") and target.id is not None: + if isinstance(target.id, str): # type: ignore[arg-type,misc] + target.id = EntityId(uuid.UUID(target.id)) # type: ignore[assignment] + + if hasattr(target, "external_service_id") and target.external_service_id is not None: + if isinstance(target.external_service_id, str): # type: ignore + target.external_service_id = EntityId(uuid.UUID(target.external_service_id)) # type: ignore + + if isinstance(target.adapter_type, str): + try: + target.adapter_type = EnergyLoadHistoryProviderAdapter(target.adapter_type) + except ValueError: + pass + + if target.config and isinstance(target.config, str): + target.config = _deserialize_energy_load_history_provider_config(target.adapter_type, target.config) + + +@event.listens_for(EnergyLoadHistoryProvider, "before_insert") +@event.listens_for(EnergyLoadHistoryProvider, "before_update") +def _flatten_energy_load_history_provider_composites(mapper, connection, target: Any) -> None: + """Convert enum attributes to primitive values before persisting.""" + if hasattr(target, "adapter_type") and target.adapter_type is not None: + if isinstance(target.adapter_type, 
EnergyLoadHistoryProviderAdapter):
+            target.adapter_type = target.adapter_type.value
+
+
+@event.listens_for(EnergyLoadHistoryProvider, "after_insert")
+@event.listens_for(EnergyLoadHistoryProvider, "after_update")
+def _restore_energy_load_history_provider_composites(mapper, connection, target: Any) -> None:
+    """Restore enum attributes after persist operations."""
+    if hasattr(target, "adapter_type") and target.adapter_type is not None:
+        if isinstance(target.adapter_type, str):
+            try:
+                target.adapter_type = EnergyLoadHistoryProviderAdapter(target.adapter_type)
+            except ValueError:
+                pass
+
+
+energy_load_history_providers_table = Table(
+    "energy_load_history_providers",
+    metadata,
+    Column("id", String, primary_key=True, index=True),
+    Column("name", String, nullable=False),
+    Column("adapter_type", String, nullable=False),
+    Column("config", EnergyLoadHistoryProviderConfigType, nullable=True),
+    Column("external_service_id", String, ForeignKey("external_services.id"), nullable=True),
+)
+
+mapper_registry.map_imperatively(
+    EnergyLoadHistoryProvider,
+    energy_load_history_providers_table,
+)
+
+
+# HomeLoadPowerPoint table (device-scoped time series).
+#
+# Not imperatively mapped: HomeLoadPowerPoint is a Value Object (frozen
+# dataclass) and the SQLAlchemy repository interacts with this table via
+# Core (insert/select statements) to keep the domain model pure.
+#
+# Composite primary key (device_id, timestamp) yields:
+# - natural uniqueness per device over time
+# - idempotent ingestion (re-fetching the same HA window is a no-op)
+# - composite index on (device_id, timestamp) for efficient range scans
+home_load_power_points_table = Table(
+    "home_load_power_points",
+    metadata,
+    Column("device_id", String, nullable=False, primary_key=True),
+    Column("timestamp", DateTime(timezone=True), nullable=False, primary_key=True),
+    Column("power", Float, nullable=False),
+    Index("ix_home_load_power_points_device_ts", "device_id", "timestamp"),
+)
+
+
+# --- LoadConsumptionModel table + mapping ---
+#
+# Stores trained ML models (Holt-Winters, XGBoost, etc.) with serialized
+# weights in `model_bytes` (LargeBinary / BLOB). The `is_active` flag
+# designates the currently promoted model per (adapter_type, device_id)
+# combination.
+
+load_consumption_models_table = Table(
+    "load_consumption_models",
+    metadata,
+    Column("id", String, primary_key=True, index=True),
+    Column("device_id", String, nullable=True),
+    Column("adapter_type", String, nullable=False),
+    Column("trained_at", DateTime(timezone=True), nullable=True),
+    Column("mae", Float, nullable=True),
+    Column("rmse", Float, nullable=True),
+    Column("samples_used", Integer, nullable=False, default=0),
+    Column("is_active", Boolean, nullable=False, default=False),
+    Column("model_bytes", LargeBinary, nullable=True),
+    Column("tuning_params", Text, nullable=True),
+    Column("backtest_mae", Float, nullable=True),
+    Column("backtest_rmse", Float, nullable=True),
+    Column("backtest_folds", Integer, nullable=False, default=0),
+    Index("ix_load_consumption_models_active", "adapter_type", "device_id", "is_active"),
+)
+
+
+@event.listens_for(LoadConsumptionModel, "load")
+def _receive_load_consumption_model_load(target: LoadConsumptionModel, context) -> None:
+    """Reconstruct domain types after loading from database."""
+    if hasattr(target, "id") and target.id is not None:
+        if isinstance(target.id, str):
+            target.id = EntityId(uuid.UUID(target.id))
+
+    if hasattr(target, "device_id") and target.device_id is not None:
+        if isinstance(target.device_id, str):
+            target.device_id = EntityId(uuid.UUID(target.device_id))
+
+    if isinstance(target.adapter_type, str):
+        try:
+            target.adapter_type = EnergyLoadForecastProviderAdapter(target.adapter_type)
+        except ValueError:
+            pass
+
+    if hasattr(target, "tuning_params") and isinstance(target.tuning_params, str):
+        try:
+            target.tuning_params = json.loads(target.tuning_params)
+        except (json.JSONDecodeError, TypeError):
+            target.tuning_params = None
+
+
+@event.listens_for(LoadConsumptionModel, "before_insert")
+@event.listens_for(LoadConsumptionModel, "before_update")
+def _flatten_load_consumption_model_composites(mapper, connection, target: Any) -> None:
+    """Convert enum attributes to primitive values before persisting."""
+    if hasattr(target, "adapter_type") and target.adapter_type is not None:
+        if isinstance(target.adapter_type, EnergyLoadForecastProviderAdapter):
+            target.adapter_type = target.adapter_type.value
+
+    if hasattr(target, "tuning_params") and target.tuning_params is not None:
+        if isinstance(target.tuning_params, dict):
+            target.tuning_params = json.dumps(target.tuning_params)
+
+
+@event.listens_for(LoadConsumptionModel, "after_insert")
+@event.listens_for(LoadConsumptionModel, "after_update")
+def _restore_load_consumption_model_composites(mapper, connection, target: Any) -> None:
+    """Restore enum attributes after persist operations."""
+    if hasattr(target, "adapter_type") and target.adapter_type is not None:
+        if isinstance(target.adapter_type, str):
+            try:
+                target.adapter_type = EnergyLoadForecastProviderAdapter(target.adapter_type)
+            except ValueError:
+                pass
+
+    if hasattr(target, "tuning_params") and isinstance(target.tuning_params, str):
+        try:
+            target.tuning_params = json.loads(target.tuning_params)
+        except (json.JSONDecodeError, TypeError):
+            target.tuning_params = 
None + + +mapper_registry.map_imperatively( + LoadConsumptionModel, + load_consumption_models_table, +) diff --git a/edge_mining/adapters/domain/optimization_unit/cli/commands.py b/edge_mining/adapters/domain/optimization_unit/cli/commands.py index 3d22625..03442f7 100644 --- a/edge_mining/adapters/domain/optimization_unit/cli/commands.py +++ b/edge_mining/adapters/domain/optimization_unit/cli/commands.py @@ -74,7 +74,6 @@ def handle_add_optimization_unit(configuration_service: ConfigurationServiceInte selected_notifiers = [selected_notifiers] # To be implemented in the next release - home_forecast_provider_id = None performance_tracker_id = None try: @@ -88,7 +87,6 @@ def handle_add_optimization_unit(configuration_service: ConfigurationServiceInte energy_source_id=selected_energy_source.id if selected_energy_source else None, target_miner_ids=target_miner_ids, policy_id=selected_policy.id if selected_policy else None, - home_forecast_provider_id=home_forecast_provider_id, performance_tracker_id=performance_tracker_id, notifier_ids=notifier_ids, ) @@ -288,7 +286,6 @@ def update_optimization_unit( new_optimization_unit.target_miner_ids = optimization_unit.target_miner_ids new_optimization_unit.policy_id = optimization_unit.policy_id new_optimization_unit.notifier_ids = optimization_unit.notifier_ids - new_optimization_unit.home_forecast_provider_id = optimization_unit.home_forecast_provider_id new_optimization_unit.performance_tracker_id = optimization_unit.performance_tracker_id click.echo("\nDo you want to change the energy source?") @@ -351,7 +348,6 @@ def update_optimization_unit( energy_source_id=new_optimization_unit.energy_source_id, target_miner_ids=new_optimization_unit.target_miner_ids, policy_id=new_optimization_unit.policy_id, - home_forecast_provider_id=new_optimization_unit.home_forecast_provider_id, performance_tracker_id=new_optimization_unit.performance_tracker_id, notifier_ids=new_optimization_unit.notifier_ids, 
is_enabled=new_optimization_unit.is_enabled, diff --git a/edge_mining/adapters/domain/optimization_unit/fast_api/router.py b/edge_mining/adapters/domain/optimization_unit/fast_api/router.py index 60926b1..de7957e 100644 --- a/edge_mining/adapters/domain/optimization_unit/fast_api/router.py +++ b/edge_mining/adapters/domain/optimization_unit/fast_api/router.py @@ -62,8 +62,8 @@ async def add_optimization_unit( policy_id=optimization_unit_to_add.policy_id, target_miner_ids=optimization_unit_to_add.target_miner_ids, energy_source_id=optimization_unit_to_add.energy_source_id, - home_forecast_provider_id=optimization_unit_to_add.home_forecast_provider_id, performance_tracker_id=optimization_unit_to_add.performance_tracker_id, + home_loads_profile_id=optimization_unit_to_add.home_loads_profile_id, notifier_ids=optimization_unit_to_add.notifier_ids, ) @@ -127,10 +127,6 @@ async def update_optimization_unit( if optimization_unit_update.energy_source_id: energy_source_id = EntityId(uuid.UUID(optimization_unit_update.energy_source_id)) - home_forecast_provider_id: Optional[EntityId] = None - if optimization_unit_update.home_forecast_provider_id: - home_forecast_provider_id = EntityId(uuid.UUID(optimization_unit_update.home_forecast_provider_id)) - performance_tracker_id: Optional[EntityId] = None if optimization_unit_update.performance_tracker_id: performance_tracker_id = EntityId(uuid.UUID(optimization_unit_update.performance_tracker_id)) @@ -139,6 +135,10 @@ if optimization_unit_update.notifier_ids: notifier_ids = [EntityId(uuid.UUID(notifier_id)) for notifier_id in optimization_unit_update.notifier_ids] + home_loads_profile_id: Optional[EntityId] = None + if optimization_unit_update.home_loads_profile_id: + home_loads_profile_id = EntityId(uuid.UUID(optimization_unit_update.home_loads_profile_id)) + # Update the optimization unit updated_unit = await config_service.update_optimization_unit( unit_id=unit_id, @@ -147,8 +147,8 @@ async def 
update_optimization_unit( policy_id=policy_id, target_miner_ids=target_miner_ids, energy_source_id=energy_source_id, - home_forecast_provider_id=home_forecast_provider_id, performance_tracker_id=performance_tracker_id, + home_loads_profile_id=home_loads_profile_id, notifier_ids=notifier_ids, ) @@ -347,6 +347,32 @@ async def remove_target_miner( raise HTTPException(status_code=500, detail=str(e)) from e +@router.post("/optimization-units/{unit_id}/home-loads-profile", response_model=EnergyOptimizationUnitSchema) +async def assign_home_loads_profile( + unit_id: EntityId, + home_loads_profile_id: Optional[EntityId] = None, + config_service: Annotated[ConfigurationServiceInterface, Depends(get_config_service)] = None, +) -> EnergyOptimizationUnitSchema: + """Assign a home loads profile to an optimization unit.""" + try: + optimization_unit = config_service.get_optimization_unit(unit_id) + + if optimization_unit is None: + raise OptimizationUnitNotFoundError(f"Optimization Unit with ID {unit_id} not found") + + updated_unit = await config_service.assign_home_loads_profile_to_optimization_unit( + unit_id, home_loads_profile_id + ) + + response = EnergyOptimizationUnitSchema.from_model(updated_unit) + + return response + except OptimizationUnitNotFoundError as e: + raise HTTPException(status_code=404, detail="Optimization Unit not found") from e + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) from e + + @router.post("/optimization-units/{unit_id}/notifiers", response_model=EnergyOptimizationUnitSchema) async def assign_notifiers( unit_id: EntityId, diff --git a/edge_mining/adapters/domain/optimization_unit/repositories.py b/edge_mining/adapters/domain/optimization_unit/repositories.py index 486ecc7..f8592ae 100644 --- a/edge_mining/adapters/domain/optimization_unit/repositories.py +++ b/edge_mining/adapters/domain/optimization_unit/repositories.py @@ -85,8 +85,8 @@ def _create_tables(self): policy_id TEXT, target_miner_ids TEXT, -- JSON list 
of MinerId strings energy_source_id TEXT, - home_forecast_provider_id TEXT, performance_tracker_id TEXT, + home_loads_profile_id TEXT, notifier_ids TEXT -- JSON list of NotifierId strings ); """ @@ -128,12 +128,14 @@ def _row_to_optimization_unit(self, row: sqlite3.Row) -> Optional[EnergyOptimiza policy_id=(EntityId(row["policy_id"]) if row["policy_id"] else None), target_miner_ids=target_miner_ids, energy_source_id=(EntityId(row["energy_source_id"]) if row["energy_source_id"] else None), - home_forecast_provider_id=( - EntityId(row["home_forecast_provider_id"]) if row["home_forecast_provider_id"] else None - ), performance_tracker_id=( EntityId(row["performance_tracker_id"]) if row["performance_tracker_id"] else None ), + home_loads_profile=( + EntityId(row["home_loads_profile_id"]) + if "home_loads_profile_id" in row.keys() and row["home_loads_profile_id"] + else None + ), notifier_ids=notifier_ids, ) except (ValueError, KeyError) as e: @@ -145,7 +147,7 @@ def add(self, optimization_unit: EnergyOptimizationUnit) -> None: self.logger.debug(f"Adding optimization unit {optimization_unit.id} to SQLite.") sql = """ INSERT INTO optimization_units (id, name, description, is_enabled, policy_id, target_miner_ids, - energy_source_id, home_forecast_provider_id, performance_tracker_id, notifier_ids) + energy_source_id, performance_tracker_id, home_loads_profile_id, notifier_ids) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?) 
""" conn = self._db.get_connection() @@ -165,8 +167,8 @@ def add(self, optimization_unit: EnergyOptimizationUnit) -> None: optimization_unit.policy_id, target_ids_json, optimization_unit.energy_source_id, - optimization_unit.home_forecast_provider_id, optimization_unit.performance_tracker_id, + str(optimization_unit.home_loads_profile) if optimization_unit.home_loads_profile else None, notifier_ids_json, ), ) @@ -250,7 +252,7 @@ def update(self, optimization_unit: EnergyOptimizationUnit) -> None: sql = """ UPDATE optimization_units SET name = ?, description = ?, is_enabled = ?, policy_id = ?, target_miner_ids = ?, energy_source_id = ?, - home_forecast_provider_id = ?, performance_tracker_id = ?, notifier_ids = ? + performance_tracker_id = ?, home_loads_profile_id = ?, notifier_ids = ? WHERE id = ? """ conn = self._db.get_connection() @@ -270,8 +272,8 @@ def update(self, optimization_unit: EnergyOptimizationUnit) -> None: optimization_unit.policy_id, target_ids_json, optimization_unit.energy_source_id, - optimization_unit.home_forecast_provider_id, optimization_unit.performance_tracker_id, + str(optimization_unit.home_loads_profile) if optimization_unit.home_loads_profile else None, notifier_ids_json, optimization_unit.id, ), @@ -384,8 +386,8 @@ def update(self, optimization_unit: EnergyOptimizationUnit) -> None: existing_entity.policy_id = optimization_unit.policy_id existing_entity.target_miner_ids = optimization_unit.target_miner_ids existing_entity.energy_source_id = optimization_unit.energy_source_id - existing_entity.home_forecast_provider_id = optimization_unit.home_forecast_provider_id existing_entity.performance_tracker_id = optimization_unit.performance_tracker_id + existing_entity.home_loads_profile = optimization_unit.home_loads_profile existing_entity.notifier_ids = optimization_unit.notifier_ids session.commit() diff --git a/edge_mining/adapters/domain/optimization_unit/schemas.py b/edge_mining/adapters/domain/optimization_unit/schemas.py index 
93304bf..85d1228 100644 --- a/edge_mining/adapters/domain/optimization_unit/schemas.py +++ b/edge_mining/adapters/domain/optimization_unit/schemas.py @@ -19,10 +19,8 @@ class EnergyOptimizationUnitSchema(BaseModel): policy_id: Optional[str] = Field(default=None, description="ID of the policy to be used for optimization") target_miner_ids: List[str] = Field(default_factory=list, description="List of target miner IDs to be controlled") energy_source_id: Optional[str] = Field(default=None, description="ID of the energy source to be used") - home_forecast_provider_id: Optional[str] = Field( - default=None, description="ID of the home load forecast provider to be used" - ) performance_tracker_id: Optional[str] = Field(default=None, description="ID of the performance tracker to be used") + home_loads_profile_id: Optional[str] = Field(default=None, description="ID of the home loads profile to be used") notifier_ids: List[str] = Field(default_factory=list, description="List of notifier IDs to be used") @field_validator("id") @@ -87,26 +85,26 @@ def validate_energy_source_id(cls, v: Optional[str]) -> Optional[str]: raise ValueError("energy_source_id must be a valid UUID string") from exc return v - @field_validator("home_forecast_provider_id") + @field_validator("performance_tracker_id") @classmethod - def validate_home_forecast_provider_id(cls, v: Optional[str]) -> Optional[str]: - """Validate that home_forecast_provider_id is a valid UUID string if provided.""" + def validate_performance_tracker_id(cls, v: Optional[str]) -> Optional[str]: + """Validate that performance_tracker_id is a valid UUID string if provided.""" if v is not None: try: uuid.UUID(v) except ValueError as exc: - raise ValueError("home_forecast_provider_id must be a valid UUID string") from exc + raise ValueError("performance_tracker_id must be a valid UUID string") from exc return v - @field_validator("performance_tracker_id") + @field_validator("home_loads_profile_id") @classmethod - def 
validate_performance_tracker_id(cls, v: Optional[str]) -> Optional[str]: - """Validate that performance_tracker_id is a valid UUID string if provided.""" + def validate_home_loads_profile_id(cls, v: Optional[str]) -> Optional[str]: + """Validate that home_loads_profile_id is a valid UUID string if provided.""" if v is not None: try: uuid.UUID(v) except ValueError as exc: - raise ValueError("performance_tracker_id must be a valid UUID string") from exc + raise ValueError("home_loads_profile_id must be a valid UUID string") from exc return v @field_validator("notifier_ids") @@ -131,14 +129,12 @@ def from_model(cls, optimization_unit: EnergyOptimizationUnit) -> "EnergyOptimiz policy_id=str(optimization_unit.policy_id) if optimization_unit.policy_id else None, target_miner_ids=[str(miner_id) for miner_id in optimization_unit.target_miner_ids], energy_source_id=str(optimization_unit.energy_source_id) if optimization_unit.energy_source_id else None, - home_forecast_provider_id=( - str(optimization_unit.home_forecast_provider_id) - if optimization_unit.home_forecast_provider_id - else None - ), performance_tracker_id=( str(optimization_unit.performance_tracker_id) if optimization_unit.performance_tracker_id else None ), + home_loads_profile_id=( + str(optimization_unit.home_loads_profile) if optimization_unit.home_loads_profile else None + ), notifier_ids=[str(notifier_id) for notifier_id in optimization_unit.notifier_ids], ) @@ -162,16 +158,16 @@ def serialize_energy_source_id(self, value: Optional[str]) -> Optional[str]: """Serialize energy_source_id field.""" return str(value) if value is not None else None - @field_serializer("home_forecast_provider_id") - def serialize_home_forecast_provider_id(self, value: Optional[str]) -> Optional[str]: - """Serialize home_forecast_provider_id field.""" - return str(value) if value is not None else None - @field_serializer("performance_tracker_id") def serialize_performance_tracker_id(self, value: Optional[str]) -> Optional[str]: 
"""Serialize performance_tracker_id field.""" return str(value) if value is not None else None + @field_serializer("home_loads_profile_id") + def serialize_home_loads_profile_id(self, value: Optional[str]) -> Optional[str]: + """Serialize home_loads_profile_id field.""" + return str(value) if value is not None else None + @field_serializer("notifier_ids") def serialize_notifier_ids(self, value: List[str]) -> List[str]: """Serialize notifier_ids field.""" @@ -187,12 +183,12 @@ def to_model(self) -> EnergyOptimizationUnit: policy_id=EntityId(uuid.UUID(self.policy_id)) if self.policy_id else None, target_miner_ids=[EntityId(uuid.UUID(miner_id)) for miner_id in self.target_miner_ids], energy_source_id=EntityId(uuid.UUID(self.energy_source_id)) if self.energy_source_id else None, - home_forecast_provider_id=( - EntityId(uuid.UUID(self.home_forecast_provider_id)) if self.home_forecast_provider_id else None - ), performance_tracker_id=( EntityId(uuid.UUID(self.performance_tracker_id)) if self.performance_tracker_id else None ), + home_loads_profile=( + EntityId(uuid.UUID(self.home_loads_profile_id)) if self.home_loads_profile_id else None + ), notifier_ids=[EntityId(uuid.UUID(notifier_id)) for notifier_id in self.notifier_ids], ) @@ -215,10 +211,8 @@ class EnergyOptimizationUnitCreateSchema(BaseModel): policy_id: Optional[str] = Field(default=None, description="ID of the policy to be used for optimization") target_miner_ids: List[str] = Field(default_factory=list, description="List of target miner IDs to be controlled") energy_source_id: Optional[str] = Field(default=None, description="ID of the energy source to be used") - home_forecast_provider_id: Optional[str] = Field( - default=None, description="ID of the home load forecast provider to be used" - ) performance_tracker_id: Optional[str] = Field(default=None, description="ID of the performance tracker to be used") + home_loads_profile_id: Optional[str] = Field(default=None, description="ID of the home loads profile to 
be used") notifier_ids: List[str] = Field(default_factory=list, description="List of notifier IDs to be used") @field_validator("name") @@ -273,26 +267,26 @@ def validate_energy_source_id(cls, v: Optional[str]) -> Optional[str]: raise ValueError("energy_source_id must be a valid UUID string") from exc return v - @field_validator("home_forecast_provider_id") + @field_validator("performance_tracker_id") @classmethod - def validate_home_forecast_provider_id(cls, v: Optional[str]) -> Optional[str]: - """Validate that home_forecast_provider_id is a valid UUID string if provided.""" + def validate_performance_tracker_id(cls, v: Optional[str]) -> Optional[str]: + """Validate that performance_tracker_id is a valid UUID string if provided.""" if v is not None: try: uuid.UUID(v) except ValueError as exc: - raise ValueError("home_forecast_provider_id must be a valid UUID string") from exc + raise ValueError("performance_tracker_id must be a valid UUID string") from exc return v - @field_validator("performance_tracker_id") + @field_validator("home_loads_profile_id") @classmethod - def validate_performance_tracker_id(cls, v: Optional[str]) -> Optional[str]: - """Validate that performance_tracker_id is a valid UUID string if provided.""" + def validate_home_loads_profile_id(cls, v: Optional[str]) -> Optional[str]: + """Validate that home_loads_profile_id is a valid UUID string if provided.""" if v is not None: try: uuid.UUID(v) except ValueError as exc: - raise ValueError("performance_tracker_id must be a valid UUID string") from exc + raise ValueError("home_loads_profile_id must be a valid UUID string") from exc return v @field_validator("notifier_ids") @@ -316,12 +310,12 @@ def to_model(self) -> EnergyOptimizationUnit: policy_id=EntityId(uuid.UUID(self.policy_id)) if self.policy_id else None, target_miner_ids=[EntityId(uuid.UUID(miner_id)) for miner_id in self.target_miner_ids], energy_source_id=EntityId(uuid.UUID(self.energy_source_id)) if self.energy_source_id else None, - 
home_forecast_provider_id=( - EntityId(uuid.UUID(self.home_forecast_provider_id)) if self.home_forecast_provider_id else None - ), performance_tracker_id=( EntityId(uuid.UUID(self.performance_tracker_id)) if self.performance_tracker_id else None ), + home_loads_profile=( + EntityId(uuid.UUID(self.home_loads_profile_id)) if self.home_loads_profile_id else None + ), notifier_ids=[EntityId(uuid.UUID(notifier_id)) for notifier_id in self.notifier_ids], ) @@ -343,10 +337,8 @@ class EnergyOptimizationUnitUpdateSchema(BaseModel): policy_id: Optional[str] = Field(default=None, description="ID of the policy to be used for optimization") target_miner_ids: List[str] = Field(default_factory=list, description="List of target miner IDs to be controlled") energy_source_id: Optional[str] = Field(default=None, description="ID of the energy source to be used") - home_forecast_provider_id: Optional[str] = Field( - default=None, description="ID of the home load forecast provider to be used" - ) performance_tracker_id: Optional[str] = Field(default=None, description="ID of the performance tracker to be used") + home_loads_profile_id: Optional[str] = Field(default=None, description="ID of the home loads profile to be used") notifier_ids: List[str] = Field(default_factory=list, description="List of notifier IDs to be used") @field_validator("name") @@ -401,26 +393,26 @@ def validate_energy_source_id(cls, v: Optional[str]) -> Optional[str]: raise ValueError("energy_source_id must be a valid UUID string") from exc return v - @field_validator("home_forecast_provider_id") + @field_validator("performance_tracker_id") @classmethod - def validate_home_forecast_provider_id(cls, v: Optional[str]) -> Optional[str]: - """Validate that home_forecast_provider_id is a valid UUID string if provided.""" + def validate_performance_tracker_id(cls, v: Optional[str]) -> Optional[str]: + """Validate that performance_tracker_id is a valid UUID string if provided.""" if v is not None: try: uuid.UUID(v) except 
ValueError as exc: - raise ValueError("home_forecast_provider_id must be a valid UUID string") from exc + raise ValueError("performance_tracker_id must be a valid UUID string") from exc return v - @field_validator("performance_tracker_id") + @field_validator("home_loads_profile_id") @classmethod - def validate_performance_tracker_id(cls, v: Optional[str]) -> Optional[str]: - """Validate that performance_tracker_id is a valid UUID string if provided.""" + def validate_home_loads_profile_id(cls, v: Optional[str]) -> Optional[str]: + """Validate that home_loads_profile_id is a valid UUID string if provided.""" if v is not None: try: uuid.UUID(v) except ValueError as exc: - raise ValueError("performance_tracker_id must be a valid UUID string") from exc + raise ValueError("home_loads_profile_id must be a valid UUID string") from exc return v @field_validator("notifier_ids") diff --git a/edge_mining/adapters/domain/optimization_unit/tables.py b/edge_mining/adapters/domain/optimization_unit/tables.py index f42e475..8a47c60 100644 --- a/edge_mining/adapters/domain/optimization_unit/tables.py +++ b/edge_mining/adapters/domain/optimization_unit/tables.py @@ -63,8 +63,8 @@ def process_result_value(self, value, dialect) -> List[EntityId]: Column("policy_id", String, nullable=True), # TODO: Add ForeignKey when policies table exists Column("target_miner_ids", EntityIdListType, nullable=False), # JSON list - could be association table Column("energy_source_id", String, ForeignKey("energy_sources.id"), nullable=True), - Column("home_forecast_provider_id", String, ForeignKey("home_forecast_providers.id"), nullable=True), Column("performance_tracker_id", String, ForeignKey("mining_performance_trackers.id"), nullable=True), + Column("home_loads_profile", String, nullable=True), Column("notifier_ids", EntityIdListType, nullable=False), # JSON list - could be association table ) diff --git a/edge_mining/adapters/domain/policy/schemas.py b/edge_mining/adapters/domain/policy/schemas.py 
index 5e7c4da..fb8684b 100644 --- a/edge_mining/adapters/domain/policy/schemas.py +++ b/edge_mining/adapters/domain/policy/schemas.py @@ -8,7 +8,7 @@ from edge_mining.adapters.domain.energy.schemas import EnergySourceSchema, EnergyStateSnapshotSchema from edge_mining.adapters.domain.forecast.schemas import ForecastSchema, SunSchema -from edge_mining.adapters.domain.home_load.schemas import ConsumptionForecastSchema +from edge_mining.adapters.domain.home_load.schemas import HomeLoadsConsumptionSchema from edge_mining.adapters.domain.miner.schemas import MinerSchema, MinerStateSnapshotSchema from edge_mining.adapters.domain.performance.schemas import MiningPerformanceSnapshotSchema from edge_mining.adapters.domain.policy.utils import FieldStructureSchema, _extract_schema_structure @@ -464,7 +464,9 @@ class DecisionalContextSchema(BaseModel): energy_source: Optional[EnergySourceSchema] = Field(None, description="Energy source information") energy_state: Optional[EnergyStateSnapshotSchema] = Field(None, description="Current energy state snapshot") forecast: Optional[ForecastSchema] = Field(None, description="Energy production forecast") - home_load_forecast: Optional[ConsumptionForecastSchema] = Field(None, description="Home consumption forecast") + home_load: Optional[HomeLoadsConsumptionSchema] = Field( + None, description="Household consumption (per-device history + forecast + totals)" + ) mining_performance: Optional[MiningPerformanceSnapshotSchema] = Field( None, description="Consolidated mining performance snapshot from the pool" ) @@ -502,9 +504,7 @@ def from_model(cls, context: DecisionalContext) -> "DecisionalContextSchema": energy_source=EnergySourceSchema.from_model(context.energy_source) if context.energy_source else None, energy_state=EnergyStateSnapshotSchema.from_model(context.energy_state) if context.energy_state else None, forecast=ForecastSchema.from_model(context.forecast) if context.forecast else None, - home_load_forecast=( - 
ConsumptionForecastSchema.from_model(context.home_load_forecast) if context.home_load_forecast else None - ), + home_load=(HomeLoadsConsumptionSchema.from_model(context.home_load) if context.home_load else None), mining_performance=( MiningPerformanceSnapshotSchema.from_model(context.mining_performance) if context.mining_performance @@ -522,7 +522,7 @@ def to_model(self) -> DecisionalContext: energy_source=self.energy_source.to_model() if self.energy_source else None, energy_state=self.energy_state.to_model() if self.energy_state else None, forecast=self.forecast.to_model() if self.forecast else None, - home_load_forecast=self.home_load_forecast.to_model() if self.home_load_forecast else None, + home_load=self.home_load.to_model() if self.home_load else None, mining_performance=(self.mining_performance.to_model() if self.mining_performance else None), sun=self.sun.to_model() if self.sun else None, miner=self.miner.to_model() if self.miner else None, diff --git a/edge_mining/adapters/infrastructure/api/main_api.py b/edge_mining/adapters/infrastructure/api/main_api.py index e8d3378..d854800 100644 --- a/edge_mining/adapters/infrastructure/api/main_api.py +++ b/edge_mining/adapters/infrastructure/api/main_api.py @@ -9,6 +9,7 @@ from edge_mining.__version__ import __version__ from edge_mining.adapters.domain.energy.fast_api.router import router as energy_router from edge_mining.adapters.domain.forecast.fast_api.router import router as forecast_router +from edge_mining.adapters.domain.home_load.fast_api.router import router as home_load_router from edge_mining.adapters.domain.miner.fast_api.router import router as miner_router from edge_mining.adapters.domain.notification.fast_api.router import router as notification_router from edge_mining.adapters.domain.optimization_unit.fast_api.router import router as optimization_unit_router @@ -81,6 +82,7 @@ async def app_lifespan(api_app: FastAPI): app.include_router(rule_engine_router, prefix="/api/v1", tags=["rule_engine"]) 
app.include_router(notification_router, prefix="/api/v1", tags=["notification"]) app.include_router(forecast_router, prefix="/api/v1", tags=["forecast"]) +app.include_router(home_load_router, prefix="/api/v1", tags=["home_load"]) app.include_router(performance_router, prefix="/api/v1", tags=["performance"]) app.include_router(ws_router, tags=["websocket"]) # Add more routers here (e.g., for configuration) diff --git a/edge_mining/adapters/infrastructure/api/setup.py b/edge_mining/adapters/infrastructure/api/setup.py index c7c3988..0fcab6b 100644 --- a/edge_mining/adapters/infrastructure/api/setup.py +++ b/edge_mining/adapters/infrastructure/api/setup.py @@ -7,6 +7,8 @@ from edge_mining.application.interfaces import ( AdapterServiceInterface, ConfigurationServiceInterface, + HomeLoadHistoryServiceInterface, + LoadForecastTrainingServiceInterface, MinerActionServiceInterface, OptimizationServiceInterface, ) @@ -99,6 +101,26 @@ async def get_optimization_service( return container.services.optimization_service +async def get_home_load_history_service( + container: ServiceContainer = Depends(get_service_container), +) -> HomeLoadHistoryServiceInterface: + """Get HomeLoadHistoryService via dependency injection.""" + return container.services.home_load_history_service + + +async def get_load_forecast_training_service( + container: ServiceContainer = Depends(get_service_container), +) -> LoadForecastTrainingServiceInterface: + """Get LoadForecastTrainingService via dependency injection.""" + service = container.services.load_forecast_training_service + if service is None: + raise HTTPException( + status_code=503, + detail="ML training service not available. 
Install ML dependencies.", + ) + return service + + async def get_logger( container: ServiceContainer = Depends(get_service_container), ) -> LoggerPort: diff --git a/edge_mining/adapters/infrastructure/cli/commands.py b/edge_mining/adapters/infrastructure/cli/commands.py index da194a4..a99bf71 100644 --- a/edge_mining/adapters/infrastructure/cli/commands.py +++ b/edge_mining/adapters/infrastructure/cli/commands.py @@ -42,7 +42,6 @@ def optimization_unit(): @click.option("--energy_source_id", help="ID of the energy source to use") @click.option("--target_miner_ids", help="Comma-separated list of target miner IDs") @click.option("--policy_id", help="ID of the policy to use") -@click.option("--home_forecast_provider_id", help="ID of the home load forecast provider") @click.option("--performance_tracker_id", help="ID of the performance tracker") @click.option("--notifier_ids", help="Comma-separated list of notifier IDs") @click.pass_context @@ -53,7 +52,6 @@ def create_optimization_unit( energy_source_id_str: str, target_miner_ids_str: str, policy_id_str: str, - home_forecast_provider_id_str: str, performance_tracker_id_str: str, notifier_ids_str: str, ): @@ -78,9 +76,6 @@ def create_optimization_unit( ) energy_source_id = EntityId(cast(UUID, energy_source_id_str)) if energy_source_id_str else None policy_id = EntityId(cast(UUID, policy_id_str)) if policy_id_str else None - home_forecast_provider_id = ( - EntityId(cast(UUID, home_forecast_provider_id_str)) if home_forecast_provider_id_str else None - ) performance_tracker_id = ( EntityId(cast(UUID, performance_tracker_id_str)) if performance_tracker_id_str else None ) @@ -92,7 +87,6 @@ def create_optimization_unit( energy_source_id=energy_source_id, target_miner_ids=target_miner_ids, policy_id=policy_id, - home_forecast_provider_id=home_forecast_provider_id, performance_tracker_id=performance_tracker_id, notifier_ids=notifier_ids, ) diff --git a/edge_mining/adapters/infrastructure/external_services/cli/commands.py 
b/edge_mining/adapters/infrastructure/external_services/cli/commands.py index b54b328..741bc5a 100644 --- a/edge_mining/adapters/infrastructure/external_services/cli/commands.py +++ b/edge_mining/adapters/infrastructure/external_services/cli/commands.py @@ -12,7 +12,7 @@ from edge_mining.domain.common import EntityId from edge_mining.domain.energy.entities import EnergyMonitor from edge_mining.domain.forecast.entities import ForecastProvider -from edge_mining.domain.home_load.entities import HomeForecastProvider +from edge_mining.domain.home_load.entities import EnergyLoadForecastProvider from edge_mining.domain.miner.entities import MinerController from edge_mining.domain.notification.entities import Notifier from edge_mining.shared.adapter_configs.external_services import ( @@ -215,7 +215,7 @@ def print_external_service_details( EnergyMonitor, MinerController, ForecastProvider, - HomeForecastProvider, + EnergyLoadForecastProvider, Notifier, ] if external_service_linked_entities.energy_monitors: @@ -236,9 +236,9 @@ def print_external_service_details( click.echo(f"-> Name: {e.name} (ID: {e.id})") click.echo("") - if external_service_linked_entities.home_forecast_providers: - click.echo("Home Forecast Providers assigned:") - for e in external_service_linked_entities.home_forecast_providers: + if external_service_linked_entities.energy_load_forecast_providers: + click.echo("Energy Load Forecast Providers assigned:") + for e in external_service_linked_entities.energy_load_forecast_providers: click.echo(f"-> Name: {e.name} (ID: {e.id})") click.echo("") diff --git a/edge_mining/adapters/infrastructure/external_services/schemas.py b/edge_mining/adapters/infrastructure/external_services/schemas.py index 85c3f5a..bb222c7 100644 --- a/edge_mining/adapters/infrastructure/external_services/schemas.py +++ b/edge_mining/adapters/infrastructure/external_services/schemas.py @@ -9,7 +9,10 @@ from edge_mining.adapters.domain.energy.schemas import EnergyMonitorSchema from 
edge_mining.adapters.domain.forecast.schemas import ForecastProviderSchema -from edge_mining.adapters.domain.home_load.schemas import HomeForecastProviderSchema +from edge_mining.adapters.domain.home_load.schemas import ( + EnergyLoadForecastProviderSchema, + EnergyLoadHistoryProviderSchema, +) from edge_mining.adapters.domain.miner.schemas import MinerControllerSchema from edge_mining.adapters.domain.notification.schemas import NotifierSchema from edge_mining.domain.common import EntityId @@ -184,7 +187,8 @@ class ExternalServiceLinkedEntitiesSchema(BaseModel): miner_controllers: List[MinerControllerSchema] energy_monitors: List[EnergyMonitorSchema] forecast_providers: List[ForecastProviderSchema] - home_forecast_providers: List[HomeForecastProviderSchema] + energy_load_forecast_providers: List[EnergyLoadForecastProviderSchema] + energy_load_history_providers: List[EnergyLoadHistoryProviderSchema] notifiers: List[NotifierSchema] @classmethod @@ -198,8 +202,13 @@ def from_model(cls, linked_entities: ExternalServiceLinkedEntities) -> "External forecast_providers=[ ForecastProviderSchema.from_model(provider) for provider in linked_entities.forecast_providers ], - home_forecast_providers=[ - HomeForecastProviderSchema.from_model(provider) for provider in linked_entities.home_forecast_providers + energy_load_forecast_providers=[ + EnergyLoadForecastProviderSchema.from_model(provider) + for provider in linked_entities.energy_load_forecast_providers + ], + energy_load_history_providers=[ + EnergyLoadHistoryProviderSchema.from_model(provider) + for provider in linked_entities.energy_load_history_providers ], notifiers=[NotifierSchema.from_model(notifier) for notifier in linked_entities.notifiers], ) @@ -210,7 +219,8 @@ def to_model(self) -> ExternalServiceLinkedEntities: miner_controllers=[item.to_model() for item in self.miner_controllers], energy_monitors=[item.to_model() for item in self.energy_monitors], forecast_providers=[item.to_model() for item in 
self.forecast_providers], - home_forecast_providers=[item.to_model() for item in self.home_forecast_providers], + energy_load_forecast_providers=[item.to_model() for item in self.energy_load_forecast_providers], + energy_load_history_providers=[item.to_model() for item in self.energy_load_history_providers], notifiers=[item.to_model() for item in self.notifiers], ) diff --git a/edge_mining/adapters/infrastructure/homeassistant/homeassistant_api.py b/edge_mining/adapters/infrastructure/homeassistant/homeassistant_api.py index 531bbdb..0d5fae8 100644 --- a/edge_mining/adapters/infrastructure/homeassistant/homeassistant_api.py +++ b/edge_mining/adapters/infrastructure/homeassistant/homeassistant_api.py @@ -13,10 +13,10 @@ import asyncio import math # For isnan -from typing import Optional, Tuple +from typing import List, Optional, Tuple import aiohttp -from homeassistant_api import Client, Domain, Entity +from homeassistant_api import Client, Domain, Entity, History from homeassistant_api.errors import ( EndpointNotFoundError, HomeassistantAPIError, @@ -25,13 +25,14 @@ UnauthorizedError, ) +from edge_mining.adapters.infrastructure.homeassistant.models import EntityHistory, HistoryDataPoint from edge_mining.adapters.infrastructure.homeassistant.utils import ( STATE_SERVICE_MAP, SWITCH_STATE_MAP, SwitchDomain, TurnService, ) -from edge_mining.domain.common import Percentage, WattHours, Watts +from edge_mining.domain.common import Percentage, Timestamp, WattHours, Watts from edge_mining.shared.adapter_configs.external_services import ( ExternalServiceHomeAssistantConfig, ) @@ -285,6 +286,86 @@ async def set_entity_state(self, entity_id: Optional[str], state: str) -> bool: self.logger.error(f"Unexpected error setting Home Assistant entity '{entity_id}': {e}") return False + async def get_entity_history(self, entity_id: str, start: Timestamp, end: Timestamp) -> Optional[EntityHistory]: + """Retrieves the history of a Home Assistant entity.""" + if self.logger: + 
self.logger.debug(f"Fetching history for entity '{entity_id}' from {start} to {end}...") + + if not entity_id: + if self.logger: + self.logger.debug("No entity_id provided for history fetch.") + return None + + if not self.client: + if self.logger: + self.logger.error("Home Assistant client is not initialized.") + return None + + # The homeassistant_api library's construct_params does not URL-encode + # query values, so a "+" in "+00:00" is interpreted as a space by the + # HA server, causing "Invalid end_time". Work around this by converting + # to naive UTC datetimes – HA treats naive timestamps as UTC. + from datetime import timezone as _tz + + _start = start.astimezone(_tz.utc).replace(tzinfo=None) if start.tzinfo else start + _end = end.astimezone(_tz.utc).replace(tzinfo=None) if end.tzinfo else end + + try: + entity: Optional[Entity] = await self.client.async_get_entity(entity_id=entity_id) + + if not entity: + if self.logger: + self.logger.warning(f"Home Assistant entity '{entity_id}' not found.") + return None + + history_for_entity: Optional[History] = None + async for history in self.client.async_get_entity_histories((entity,), _start, _end): + history_for_entity = history + break + + if not history_for_entity: + if self.logger: + self.logger.debug(f"No history found for entity '{entity_id}'.") + return None + + if self.logger: + self.logger.debug( + f"Retrieved history for entity '{entity_id}' with {len(history_for_entity.states)} entries." + ) + + # history_for_entity.states is a tuple of State objects. + # We iterate over it to create a list of HistoryDataPoint objects. + data_points: List[HistoryDataPoint] = [] + for state in history_for_entity.states: + if state.last_updated is None: + if self.logger: + self.logger.warning( + f"State entry for entity '{entity_id}' has no 'last_updated' timestamp. Skipping entry." 
+ ) + continue + + data_points.append( + HistoryDataPoint( + timestamp=Timestamp(state.last_updated), + value=state.state, + unit=state.attributes.get("unit_of_measurement", ""), + ) + ) + + if self.logger: + self.logger.debug( + f"Retrieved and processed {len(data_points)} history entries for entity '{entity_id}'." + ) + + entity_history = EntityHistory(entity_id=entity_id, history=data_points) + entity_history.sort_by_timestamp() + + return entity_history + except Exception as e: + if self.logger: + self.logger.error(f"Error fetching history for entity '{entity_id}': {e}") + return None + def parse_power( self, state: Optional[str], diff --git a/edge_mining/adapters/infrastructure/homeassistant/models.py b/edge_mining/adapters/infrastructure/homeassistant/models.py new file mode 100644 index 0000000..8fef285 --- /dev/null +++ b/edge_mining/adapters/infrastructure/homeassistant/models.py @@ -0,0 +1,27 @@ +"""Collection of data models for Home Assistant integration.""" + +from dataclasses import dataclass +from typing import List, Optional + +from edge_mining.domain.common import Timestamp + + +@dataclass +class HistoryDataPoint: + """A single data point in the history of an entity.""" + + timestamp: Timestamp + value: str + unit: Optional[str] + + +@dataclass +class EntityHistory: + """Historical data for a specific entity.""" + + entity_id: str + history: List[HistoryDataPoint] + + def sort_by_timestamp(self) -> None: + """Sorts the history data points by their timestamp.""" + self.history.sort(key=lambda point: point.timestamp) diff --git a/edge_mining/adapters/infrastructure/homeassistant/utils.py b/edge_mining/adapters/infrastructure/homeassistant/utils.py index f0cdedd..28fd0da 100644 --- a/edge_mining/adapters/infrastructure/homeassistant/utils.py +++ b/edge_mining/adapters/infrastructure/homeassistant/utils.py @@ -19,6 +19,15 @@ class TurnService(Enum): TURN_OFF = "turn_off" +class EntityState(Enum): + """Enum for the different states of an entity.""" + + 
ON = "on" + OFF = "off" + UNAVAILABLE = "unavailable" + UNKNOWN = "unknown" + + STATE_SERVICE_MAP: Dict[str, TurnService] = { "on": TurnService.TURN_ON, "true": TurnService.TURN_ON, diff --git a/edge_mining/adapters/infrastructure/rule_engine/custom/helpers.py b/edge_mining/adapters/infrastructure/rule_engine/custom/helpers.py index 0ed78b4..21fcd07 100644 --- a/edge_mining/adapters/infrastructure/rule_engine/custom/helpers.py +++ b/edge_mining/adapters/infrastructure/rule_engine/custom/helpers.py @@ -80,12 +80,21 @@ def _evaluate_logical_group(context: DecisionalContext, group: LogicalGroupSchem @staticmethod def _get_field_value(context: DecisionalContext, field_path: str) -> Any: - """Get value from DecisionalContext using dot notation.""" + """Get value from DecisionalContext using dot notation. + + Supports: + - Attribute access (dataclass fields, properties) + - Dict key lookup (e.g. ``home_load.devices.boiler``) + """ parts = field_path.split(".") current = context for part in parts: - if hasattr(current, part): + if current is None: + return None + if isinstance(current, dict): + current = current.get(part) + elif hasattr(current, part): current = getattr(current, part) else: return None diff --git a/edge_mining/adapters/infrastructure/sheduler/jobs.py b/edge_mining/adapters/infrastructure/sheduler/jobs.py index 9fb0528..c2e8ebe 100644 --- a/edge_mining/adapters/infrastructure/sheduler/jobs.py +++ b/edge_mining/adapters/infrastructure/sheduler/jobs.py @@ -1,8 +1,11 @@ """Job scheduler for running optimization tasks at regular intervals.""" +from typing import Optional + from apscheduler.schedulers.asyncio import AsyncIOScheduler -from edge_mining.application.interfaces import OptimizationServiceInterface +from edge_mining.application.interfaces import HomeLoadHistoryServiceInterface, OptimizationServiceInterface +from edge_mining.application.services.load_forecast_training_service import LoadForecastModelTrainingService from edge_mining.shared.logging.port 
import LoggerPort from edge_mining.shared.scheduler.port import SchedulerPort from edge_mining.shared.settings.settings import AppSettings @@ -16,13 +19,20 @@ def __init__( optimization_service: OptimizationServiceInterface, logger: LoggerPort, settings: AppSettings, + home_load_history_service: Optional[HomeLoadHistoryServiceInterface] = None, + load_forecast_training_service: Optional[LoadForecastModelTrainingService] = None, ): self.optimization_service = optimization_service + self.home_load_history_service = home_load_history_service + self.load_forecast_training_service = load_forecast_training_service self.logger = logger self.settings = settings self.scheduler = AsyncIOScheduler(timezone=self.settings.timezone) self._job_id = "evaluate_mining" + self._history_collect_job_id = "collect_load_history" + self._history_purge_job_id = "purge_load_history" + self._model_training_job_id = "train_load_forecast_models" async def _run_evaluation_job(self): """Wrapper to call the optimization service's run method.""" @@ -32,6 +42,38 @@ async def _run_evaluation_job(self): except Exception as e: self.logger.error(f"Error during scheduled job: {self._job_id}. {e}") + async def _run_history_collect_job(self): + """Collect power points from all history providers.""" + self.logger.debug(f"Scheduler triggered. Running job: {self._history_collect_job_id}.") + if not self.home_load_history_service: + return + try: + await self.home_load_history_service.collect_all() + except Exception as e: + self.logger.error(f"Error during scheduled job: {self._history_collect_job_id}. {e}") + + async def _run_history_purge_job(self): + """Purge old power points beyond retention window.""" + self.logger.debug(f"Scheduler triggered. 
Running job: {self._history_purge_job_id}.") + if not self.home_load_history_service: + return + try: + await self.home_load_history_service.purge_all( + retention_days=self.settings.history_retention_days, + ) + except Exception as e: + self.logger.error(f"Error during scheduled job: {self._history_purge_job_id}. {e}") + + async def _run_model_training_job(self): + """Train ML forecast models on collected history.""" + self.logger.debug(f"Scheduler triggered. Running job: {self._model_training_job_id}.") + if not self.load_forecast_training_service: + return + try: + await self.load_forecast_training_service.train_all() + except Exception as e: + self.logger.error(f"Error during scheduled job: {self._model_training_job_id}. {e}") + async def start(self): """Adds the job and starts the scheduler.""" interval = self.settings.scheduler_interval_seconds @@ -45,6 +87,39 @@ async def start(self): replace_existing=True, ) + if self.home_load_history_service: + ingestion_interval = self.settings.history_ingestion_interval_seconds + self.logger.debug( + f"Scheduling history ingestion every {ingestion_interval}s " + f"and purge daily (retention={self.settings.history_retention_days}d)." 
+ ) + self.scheduler.add_job( + self._run_history_collect_job, + "interval", + seconds=ingestion_interval, + id=self._history_collect_job_id, + replace_existing=True, + ) + self.scheduler.add_job( + self._run_history_purge_job, + "cron", + hour=3, + minute=0, + id=self._history_purge_job_id, + replace_existing=True, + ) + + if self.load_forecast_training_service: + self.logger.debug("Scheduling nightly ML model training at 04:00.") + self.scheduler.add_job( + self._run_model_training_job, + "cron", + hour=4, + minute=0, + id=self._model_training_job_id, + replace_existing=True, + ) + self.logger.debug("Scheduler started.") self.scheduler.start() diff --git a/edge_mining/application/interfaces.py b/edge_mining/application/interfaces.py index 0843372..c41c9cb 100644 --- a/edge_mining/application/interfaces.py +++ b/edge_mining/application/interfaces.py @@ -4,7 +4,7 @@ from datetime import datetime from typing import Any, Callable, Dict, List, Optional, Type -from edge_mining.domain.common import DomainEvent, EntityId, Watts +from edge_mining.domain.common import DomainEvent, EntityId, Timestamp, Watts from edge_mining.domain.energy.common import EnergyMonitorAdapter, EnergySourceType from edge_mining.domain.energy.entities import EnergyMonitor, EnergySource from edge_mining.domain.energy.ports import EnergyMonitorPort @@ -12,7 +12,19 @@ from edge_mining.domain.forecast.common import ForecastProviderAdapter from edge_mining.domain.forecast.entities import ForecastProvider from edge_mining.domain.forecast.ports import ForecastProviderPort -from edge_mining.domain.home_load.ports import HomeForecastProviderPort +from edge_mining.domain.home_load.aggregate_roots import HomeLoadsProfile +from edge_mining.domain.home_load.entities import ( + EnergyLoadForecastProvider, + EnergyLoadHistoryProvider, + LoadConsumptionModel, + LoadDevice, +) +from edge_mining.domain.home_load.common import ( + EnergyLoadForecastProviderAdapter, + EnergyLoadHistoryProviderAdapter, +) +from 
edge_mining.domain.home_load.ports import EnergyLoadForecastProviderPort, EnergyLoadHistoryProviderPort +from edge_mining.domain.home_load.value_objects import HomeLoadPowerPoint from edge_mining.domain.miner.aggregate_roots import Miner from edge_mining.domain.miner.common import MinerControllerAdapter, MinerFeatureType from edge_mining.domain.miner.entities import MinerController @@ -84,10 +96,16 @@ async def get_forecast_provider(self, energy_source: EnergySource) -> Optional[F @abstractmethod def get_home_load_forecast_provider( - self, home_forecast_provider_id: EntityId - ) -> Optional[HomeForecastProviderPort]: + self, energy_load_forecast_provider_id: EntityId + ) -> Optional[EnergyLoadForecastProviderPort]: """Get an home load forecast provider adapter instance.""" + @abstractmethod + async def get_home_load_history_provider( + self, energy_load_history_provider_id: EntityId, device_id: EntityId + ) -> Optional[EnergyLoadHistoryProviderPort]: + """Get an energy load history provider adapter instance.""" + @abstractmethod async def get_mining_performance_tracker(self, tracker_id: EntityId) -> Optional[MiningPerformanceTrackerPort]: """Get a mining performance tracker adapter instance.""" @@ -129,6 +147,50 @@ async def get_decisional_context(self, optimization_unit_id: EntityId) -> Option """Get the decisional context for a specific optimization unit.""" +class HomeLoadHistoryServiceInterface(ABC): + """Base interface for home load history ingestion and retention.""" + + @abstractmethod + async def collect_all(self, lookback_hours: int = 24) -> None: + """Collect power points from all history providers for all enabled devices.""" + + @abstractmethod + async def collect_devices(self, device_ids: List[EntityId], lookback_hours: int = 24) -> None: + """Collect power points for the specified devices only.""" + + @abstractmethod + async def purge_all(self, retention_days: int = 90) -> None: + """Purge power points older than retention_days for all devices.""" + + 
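The `purge_all(retention_days=90)` contract above implies a simple cutoff computation. A minimal standalone sketch, assuming an in-memory list of timestamps (the function name and store are illustrative, not the repository implementation):

```python
from datetime import datetime, timedelta, timezone
from typing import List, Optional


def purge_old_points(
    points: List[datetime],
    retention_days: int = 90,
    now: Optional[datetime] = None,
) -> List[datetime]:
    """Keep only timestamps newer than now - retention_days (illustrative sketch)."""
    now = now or datetime.now(timezone.utc)
    cutoff = now - timedelta(days=retention_days)
    return [ts for ts in points if ts >= cutoff]


# Points 1 and 30 days old survive a 90-day retention; 91 and 200 days are purged.
now = datetime(2024, 6, 1, tzinfo=timezone.utc)
points = [now - timedelta(days=d) for d in (1, 30, 91, 200)]
print(len(purge_old_points(points, retention_days=90, now=now)))  # 2
```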
@abstractmethod + def get_device_history(self, device_id: EntityId, start: Timestamp, end: Timestamp) -> List[HomeLoadPowerPoint]: + """Retrieve stored power points for a device in a time window.""" + + @abstractmethod + def clear_device_history(self, device_id: EntityId) -> int: + """Delete all stored power points for a device. Returns the number of rows deleted.""" + + +class LoadForecastTrainingServiceInterface(ABC): + """Base interface for ML model training and model listing.""" + + @abstractmethod + async def train_all(self, weeks_lookback: int = 8) -> None: + """Train models for every device that has sufficient history.""" + + @abstractmethod + async def train_device(self, device_id: EntityId, weeks_lookback: int = 8) -> None: + """Train models for a single device.""" + + @abstractmethod + def get_models(self, device_id: Optional[EntityId] = None) -> List[LoadConsumptionModel]: + """Retrieve trained models, optionally filtered by device.""" + + @abstractmethod + def delete_model(self, model_id: EntityId) -> None: + """Delete a trained model by ID.""" + + class MinerActionServiceInterface(ABC): """Base interface for miner action services in the Edge Mining application.""" @@ -507,8 +569,8 @@ async def create_optimization_unit( policy_id: Optional[EntityId] = None, target_miner_ids: Optional[List[EntityId]] = None, energy_source_id: Optional[EntityId] = None, - home_forecast_provider_id: Optional[EntityId] = None, performance_tracker_id: Optional[EntityId] = None, + home_loads_profile_id: Optional[EntityId] = None, notifier_ids: Optional[List[EntityId]] = None, ) -> Optional[EnergyOptimizationUnit]: """Create an optimization unit into the system.""" @@ -527,7 +589,6 @@ def filter_optimization_units( filter_by_miners: Optional[List[EntityId]] = None, filter_by_energy_source: Optional[EntityId] = None, filter_by_policy: Optional[EntityId] = None, - filter_by_home_forecast_provider: Optional[EntityId] = None, filter_by_performance_tracker: Optional[EntityId] = 
None, filter_by_notifiers: Optional[List[EntityId]] = None, ) -> List[EnergyOptimizationUnit]: @@ -547,8 +608,8 @@ async def update_optimization_unit( policy_id: Optional[EntityId] = None, target_miner_ids: Optional[List[EntityId]] = None, energy_source_id: Optional[EntityId] = None, - home_forecast_provider_id: Optional[EntityId] = None, performance_tracker_id: Optional[EntityId] = None, + home_loads_profile_id: Optional[EntityId] = None, notifier_ids: Optional[List[EntityId]] = None, ) -> EnergyOptimizationUnit: """Update an optimization unit in the system.""" @@ -589,18 +650,18 @@ async def assign_energy_source_to_optimization_unit( ) -> EnergyOptimizationUnit: """Assign an energy source to an optimization unit.""" - @abstractmethod - async def assign_home_forecast_provider_to_optimization_unit( - self, unit_id: EntityId, home_forecast_provider_id: EntityId - ) -> EnergyOptimizationUnit: - """Assign a home forecast provider to an optimization unit.""" - @abstractmethod async def assign_performance_tracker_to_optimization_unit( self, unit_id: EntityId, performance_tracker_id: EntityId ) -> EnergyOptimizationUnit: """Assign a performance tracker to an optimization unit.""" + @abstractmethod + async def assign_home_loads_profile_to_optimization_unit( + self, unit_id: EntityId, home_loads_profile_id: Optional[EntityId] + ) -> EnergyOptimizationUnit: + """Assign a home loads profile to an optimization unit.""" + @abstractmethod async def assign_notifiers_to_optimization_unit( self, unit_id: EntityId, notifier_ids: List[EntityId] @@ -831,6 +892,93 @@ async def update_forecast_provider( def check_forecast_provider(self, provider: ForecastProvider) -> bool: """Check if a forecast provider is valid and can be used.""" + # --- Home loads Management --- + @abstractmethod + def add_home_loads_profile(self, name: str) -> HomeLoadsProfile: + """Add a home loads profile to the system.""" + + @abstractmethod + def get_home_loads_profile(self, profile_id: EntityId) -> 
Optional[HomeLoadsProfile]: + """Get a home loads profile by its ID.""" + + @abstractmethod + def list_home_loads_profiles(self) -> List[HomeLoadsProfile]: + """List all home loads profiles in the system.""" + + @abstractmethod + def remove_home_loads_profile(self, profile_id: EntityId) -> HomeLoadsProfile: + """Remove a home loads profile from the system. Raises HomeLoadsProfileNotFoundError.""" + + @abstractmethod + def update_home_loads_profile(self, profile_id: EntityId, name: str) -> HomeLoadsProfile: + """Update a home loads profile in the system. Raises HomeLoadsProfileNotFoundError.""" + + @abstractmethod + def add_load_device_to_profile(self, profile_id: EntityId, load_device: LoadDevice) -> LoadDevice: + """Add a load device to a home loads profile. Raises HomeLoadsProfileNotFoundError.""" + + @abstractmethod + def remove_load_device_from_profile( + self, + profile_id: EntityId, + device_id: EntityId, + ) -> LoadDevice: + """Remove a load device from a home loads profile. Raises on missing profile or device.""" + + # --- Energy Load Forecast Provider Management --- + @abstractmethod + def add_energy_load_forecast_provider(self, provider: EnergyLoadForecastProvider) -> EnergyLoadForecastProvider: + """Add a new energy load forecast provider.""" + + @abstractmethod + def get_energy_load_forecast_provider(self, provider_id: EntityId) -> Optional[EnergyLoadForecastProvider]: + """Get an energy load forecast provider by ID.""" + + @abstractmethod + def list_energy_load_forecast_providers(self) -> List[EnergyLoadForecastProvider]: + """List all energy load forecast providers.""" + + @abstractmethod + def update_energy_load_forecast_provider(self, provider: EnergyLoadForecastProvider) -> EnergyLoadForecastProvider: + """Update an existing energy load forecast provider.""" + + @abstractmethod + def remove_energy_load_forecast_provider(self, provider_id: EntityId) -> EnergyLoadForecastProvider: + """Remove an energy load forecast provider.""" + + @abstractmethod + 
def get_energy_load_forecast_provider_external_service_adapter( + self, adapter_type: EnergyLoadForecastProviderAdapter + ) -> Optional[ExternalServiceAdapter]: + """Get the external service adapter type for a specific energy load forecast provider adapter type.""" + + # --- Energy Load History Provider Management --- + @abstractmethod + def add_energy_load_history_provider(self, provider: EnergyLoadHistoryProvider) -> EnergyLoadHistoryProvider: + """Add a new energy load history provider.""" + + @abstractmethod + def get_energy_load_history_provider(self, provider_id: EntityId) -> Optional[EnergyLoadHistoryProvider]: + """Get an energy load history provider by ID.""" + + @abstractmethod + def list_energy_load_history_providers(self) -> List[EnergyLoadHistoryProvider]: + """List all energy load history providers.""" + + @abstractmethod + def update_energy_load_history_provider(self, provider: EnergyLoadHistoryProvider) -> EnergyLoadHistoryProvider: + """Update an existing energy load history provider.""" + + @abstractmethod + def remove_energy_load_history_provider(self, provider_id: EntityId) -> EnergyLoadHistoryProvider: + """Remove an energy load history provider.""" + + @abstractmethod + def get_energy_load_history_provider_external_service_adapter( + self, adapter_type: EnergyLoadHistoryProviderAdapter + ) -> Optional[ExternalServiceAdapter]: + """Get the external service adapter type for a specific energy load history provider adapter type.""" + @abstractmethod def get_forecast_provider_config_by_type( self, adapter_type: ForecastProviderAdapter diff --git a/edge_mining/application/services/adapter_service.py b/edge_mining/application/services/adapter_service.py index d4c8be9..610715c 100644 --- a/edge_mining/application/services/adapter_service.py +++ b/edge_mining/application/services/adapter_service.py @@ -8,7 +8,32 @@ from edge_mining.adapters.domain.energy.monitors.home_assistant_api import HomeAssistantAPIEnergyMonitorFactory from 
edge_mining.adapters.domain.forecast.providers.dummy_solar import DummyForecastProviderFactory from edge_mining.adapters.domain.forecast.providers.home_assistant_api import HomeAssistantForecastProviderFactory -from edge_mining.adapters.domain.home_load.providers.dummy import DummyHomeForecastProvider +from edge_mining.adapters.domain.home_load.forecast_providers.dummy import DummyEnergyLoadForecastProviderFactory +from edge_mining.adapters.domain.home_load.forecast_providers.naive_last_hour import ( + NaiveLastHourForecastProviderFactory, +) +from edge_mining.adapters.domain.home_load.forecast_providers.naive_persistence import ( + NaivePersistenceForecastProviderFactory, +) +from edge_mining.adapters.domain.home_load.forecast_providers.seasonal_baseline import ( + SeasonalBaselineForecastProviderFactory, +) +from edge_mining.adapters.domain.home_load.forecast_providers.skforecast_provider import ( + SkforecastForecastProviderFactory, +) +from edge_mining.adapters.domain.home_load.forecast_providers.statsmodels_hw import ( + StatsmodelsForecastProviderFactory, +) +from edge_mining.adapters.domain.home_load.forecast_providers.typical_profile import ( + TypicalProfileForecastProviderFactory, +) +from edge_mining.adapters.domain.home_load.forecast_providers.xgboost_provider import ( + XGBoostForecastProviderFactory, +) +from edge_mining.adapters.domain.home_load.history_providers.dummy import DummyEnergyLoadHistoryProvider +from edge_mining.adapters.domain.home_load.history_providers.home_assistant_api_history import ( + HomeAssistantAPIEnergyLoadHistoryProviderFactory, +) from edge_mining.adapters.domain.miner.controllers.dummy import DummyMinerController from edge_mining.adapters.domain.miner.controllers.generic_socket_home_assistant_api import ( GenericSocketHomeAssistantAPIMinerControllerAdapterFactory, @@ -33,9 +58,16 @@ from edge_mining.domain.forecast.common import ForecastProviderAdapter from edge_mining.domain.forecast.entities import ForecastProvider from 
edge_mining.domain.forecast.ports import ForecastProviderPort, ForecastProviderRepository -from edge_mining.domain.home_load.common import HomeForecastProviderAdapter -from edge_mining.domain.home_load.entities import HomeForecastProvider -from edge_mining.domain.home_load.ports import HomeForecastProviderPort, HomeForecastProviderRepository +from edge_mining.domain.home_load.common import EnergyLoadForecastProviderAdapter, EnergyLoadHistoryProviderAdapter +from edge_mining.domain.home_load.entities import EnergyLoadForecastProvider, EnergyLoadHistoryProvider, LoadDevice +from edge_mining.domain.home_load.ports import ( + EnergyLoadForecastProviderPort, + EnergyLoadForecastProviderRepository, + EnergyLoadHistoryProviderPort, + EnergyLoadHistoryProviderRepository, + EnergyLoadHistoryRepository, + LoadConsumptionModelRepository, +) from edge_mining.domain.miner.aggregate_roots import Miner from edge_mining.domain.miner.common import MinerControllerAdapter, MinerFeatureType from edge_mining.domain.miner.entities import MinerController @@ -53,6 +85,7 @@ from edge_mining.shared.external_services.entities import ExternalService from edge_mining.shared.external_services.ports import ExternalServicePort, ExternalServiceRepository from edge_mining.shared.interfaces.factories import ( + EnergyLoadForecastAdapterFactory, EnergyMonitorAdapterFactory, ExternalServiceFactory, ForecastAdapterFactory, @@ -75,10 +108,13 @@ def __init__( notifier_repo: NotifierRepository, forecast_provider_repo: ForecastProviderRepository, mining_performance_tracker_repo: MiningPerformanceTrackerRepository, - home_forecast_provider_repo: HomeForecastProviderRepository, + energy_load_forecast_provider_repo: EnergyLoadForecastProviderRepository, + energy_load_history_provider_repo: EnergyLoadHistoryProviderRepository, + home_load_history_repo: EnergyLoadHistoryRepository, external_service_repo: ExternalServiceRepository, event_bus: EventBusInterface, logger: Optional[LoggerPort] = None, + 
load_consumption_model_repo: Optional[LoadConsumptionModelRepository] = None, ): self.energy_monitor_repo = energy_monitor_repo self.miner_controller_repo = miner_controller_repo @@ -86,8 +122,11 @@ def __init__( self.notifier_repo = notifier_repo self.forecast_provider_repo = forecast_provider_repo self.mining_performance_tracker_repo = mining_performance_tracker_repo - self.home_forecast_provider_repo = home_forecast_provider_repo + self.energy_load_forecast_provider_repo = energy_load_forecast_provider_repo + self.energy_load_history_provider_repo = energy_load_history_provider_repo + self.home_load_history_repo = home_load_history_repo self.external_service_repo = external_service_repo + self.load_consumption_model_repo = load_consumption_model_repo # Cache for already created instances self._instance_cache: Dict[ EntityId, @@ -97,7 +136,8 @@ def __init__( MinerFeaturePort, NotificationPort, ForecastProviderPort, - HomeForecastProviderPort, + EnergyLoadForecastProviderPort, + EnergyLoadHistoryProviderPort, MiningPerformanceTrackerPort, ] ], @@ -481,19 +521,19 @@ async def _initialize_forecast_provider_adapter( ) return None - def _initialize_home_forecast_provider_adapter( - self, home_forecast_provider: HomeForecastProvider - ) -> Optional[HomeForecastProviderPort]: + def _initialize_energy_load_forecast_provider_adapter( + self, energy_load_forecast_provider: EnergyLoadForecastProvider + ) -> Optional[EnergyLoadForecastProviderPort]: """Initialize a home forecast provider adapter.""" # If the adapter has already been created, we use it. 
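The cache-then-create flow the adapter service uses for provider instances can be sketched generically. This is a hedged sketch of the pattern only; `get_or_create`, the plain dict cache, and `create_fn` are illustrative stand-ins for `_instance_cache` and the adapter factories, not the project's actual API:

```python
from typing import Callable, Dict, List, Optional, TypeVar

T = TypeVar("T")


def get_or_create(cache: Dict[str, T], key: str, create_fn: Callable[[], T]) -> Optional[T]:
    """Return the cached instance if present; otherwise build, cache, and return it.

    Returns None (and caches nothing) when creation fails, so a later call
    can retry initialization.
    """
    if key in cache:
        return cache[key]
    try:
        instance = create_fn()
    except Exception:
        return None
    cache[key] = instance
    return instance


# The factory runs once; the second lookup is served from the cache.
calls: List[int] = []


def build() -> str:
    calls.append(1)
    return "adapter-instance"


cache: Dict[str, str] = {}
print(get_or_create(cache, "provider-1", build))  # adapter-instance
print(get_or_create(cache, "provider-1", build))  # adapter-instance (cached)
print(len(calls))  # 1
```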
- if home_forecast_provider.id in self._instance_cache: + if energy_load_forecast_provider.id in self._instance_cache: if self.logger: self.logger.debug( f"Returning cached adapter instance " - f"for home forecast provider ID {home_forecast_provider.id} " - f"(Type: {home_forecast_provider.adapter_type})" + f"for home forecast provider ID {energy_load_forecast_provider.id} " + f"(Type: {energy_load_forecast_provider.adapter_type})" ) - cached_instance = self._instance_cache[home_forecast_provider.id] + cached_instance = self._instance_cache[energy_load_forecast_provider.id] if not cached_instance: # If the cached instance is None, we return it @@ -501,16 +541,16 @@ def _initialize_home_forecast_provider_adapter( if self.logger: self.logger.warning( f"Cached instance for home forecast provider ID " - f"{home_forecast_provider.id} is None. Reinitializing adapter." + f"{energy_load_forecast_provider.id} is None. Reinitializing adapter." ) return None # Check if the cached instance is of the correct type - if not isinstance(cached_instance, HomeForecastProviderPort): + if not isinstance(cached_instance, EnergyLoadForecastProviderPort): if self.logger: self.logger.warning( f"Cached instance for home forecast provider ID " - f"{home_forecast_provider.id} is not of type HomeForecastProviderPort. " + f"{energy_load_forecast_provider.id} is not of type EnergyLoadForecastProviderPort. " "Reinitializing adapter." ) return None @@ -519,23 +559,42 @@ def _initialize_home_forecast_provider_adapter( return cached_instance try: - if home_forecast_provider.adapter_type == HomeForecastProviderAdapter.DUMMY: - # --- Dummy Home Forecast Provider --- - # TODO - Add configuration parameters for DummyHomeForecastProvider - # For now, we use a default load power max of 800W. 
- instance = DummyHomeForecastProvider(load_power_max=800) + factory: Optional[EnergyLoadForecastAdapterFactory] = None + + if energy_load_forecast_provider.adapter_type == EnergyLoadForecastProviderAdapter.DUMMY: + factory = DummyEnergyLoadForecastProviderFactory() + elif energy_load_forecast_provider.adapter_type == EnergyLoadForecastProviderAdapter.NAIVE_LAST_HOUR: + factory = NaiveLastHourForecastProviderFactory() + elif energy_load_forecast_provider.adapter_type == EnergyLoadForecastProviderAdapter.NAIVE_PERSISTENCE: + factory = NaivePersistenceForecastProviderFactory() + elif energy_load_forecast_provider.adapter_type == EnergyLoadForecastProviderAdapter.SEASONAL_BASELINE: + factory = SeasonalBaselineForecastProviderFactory() + elif energy_load_forecast_provider.adapter_type == EnergyLoadForecastProviderAdapter.SKFORECAST: + factory = SkforecastForecastProviderFactory(model_repo=self.load_consumption_model_repo) + elif energy_load_forecast_provider.adapter_type == EnergyLoadForecastProviderAdapter.STATSMODELS: + factory = StatsmodelsForecastProviderFactory(model_repo=self.load_consumption_model_repo) + elif energy_load_forecast_provider.adapter_type == EnergyLoadForecastProviderAdapter.TYPICAL_PROFILE: + factory = TypicalProfileForecastProviderFactory() + elif energy_load_forecast_provider.adapter_type == EnergyLoadForecastProviderAdapter.XGBOOST: + factory = XGBoostForecastProviderFactory(model_repo=self.load_consumption_model_repo) else: raise ValueError( - f"Unsupported home forecast provider adapter type: {home_forecast_provider.adapter_type}" + f"Unsupported energy load forecast provider adapter type: {energy_load_forecast_provider.adapter_type}" ) - self._instance_cache[home_forecast_provider.id] = instance + instance = factory.create( + config=energy_load_forecast_provider.config, + logger=self.logger, + external_service=None, + ) + + self._instance_cache[energy_load_forecast_provider.id] = instance return instance except Exception as e: if self.logger:
self.logger.error( - f"Failed to initialize adapter '{home_forecast_provider.name}' " - f"(Type: {home_forecast_provider.adapter_type}) using factory: {e}" + f"Failed to initialize adapter '{energy_load_forecast_provider.name}' " + f"(Type: {energy_load_forecast_provider.adapter_type}) using factory: {e}" ) return None @@ -781,17 +840,100 @@ async def get_forecast_provider(self, energy_source: EnergySource) -> Optional[F return await self._initialize_forecast_provider_adapter(energy_source, forecast_provider) def get_home_load_forecast_provider( - self, home_forecast_provider_id: EntityId - ) -> Optional[HomeForecastProviderPort]: + self, energy_load_forecast_provider_id: EntityId + ) -> Optional[EnergyLoadForecastProviderPort]: """Get a home load forecast provider adapter instance.""" - home_forecast_provider = self.home_forecast_provider_repo.get_by_id(home_forecast_provider_id) - if not home_forecast_provider: + energy_load_forecast_provider = self.energy_load_forecast_provider_repo.get_by_id( + energy_load_forecast_provider_id + ) + if not energy_load_forecast_provider: + if self.logger: + self.logger.error( + f"Energy Load Forecast Provider ID {energy_load_forecast_provider_id} not found."
+ ) + return None + return self._initialize_energy_load_forecast_provider_adapter(energy_load_forecast_provider) + + async def _initialize_energy_load_history_provider_adapter( + self, energy_load_history_provider: EnergyLoadHistoryProvider, device_id: EntityId + ) -> Optional[EnergyLoadHistoryProviderPort]: + """Initialize an energy load history provider adapter.""" + cache_key = energy_load_history_provider.id + if cache_key in self._instance_cache: + cached_instance = self._instance_cache[cache_key] + if cached_instance and isinstance(cached_instance, EnergyLoadHistoryProviderPort): + return cached_instance + return None + + # Resolve external service if needed + external_service: Optional[ExternalServicePort] = None + if energy_load_history_provider.external_service_id: + external_service = await self.get_external_service(energy_load_history_provider.external_service_id) + if not external_service: + raise ValueError( + f"Unable to load external service {energy_load_history_provider.external_service_id} " + f"for history provider {energy_load_history_provider.name}" + ) + + try: + instance: Optional[EnergyLoadHistoryProviderPort] = None + + if energy_load_history_provider.adapter_type == EnergyLoadHistoryProviderAdapter.DUMMY: + instance = DummyEnergyLoadHistoryProvider( + device_id=device_id, + history_repo=self.home_load_history_repo, + logger=self.logger, + ) + elif energy_load_history_provider.adapter_type == EnergyLoadHistoryProviderAdapter.HOME_ASSISTANT_API: + if not energy_load_history_provider.config: + raise ValueError( + "EnergyLoadHistoryProvider config is required for HomeAssistantAPI history provider." + ) + if not external_service: + raise ValueError( + f"External service is required for HomeAssistantAPI history provider " + f"'{energy_load_history_provider.name}'. " + f"Please set external_service_id on the provider." 
+ ) + factory = HomeAssistantAPIEnergyLoadHistoryProviderFactory( + history_repo=self.home_load_history_repo, + ) + # Build a minimal LoadDevice for the factory binding + load_device = LoadDevice(id=device_id) + factory.from_load_device(load_device) + instance = factory.create( + config=energy_load_history_provider.config, + logger=self.logger, + external_service=external_service, + ) + else: + raise ValueError( + f"Unsupported energy load history provider adapter type: " + f"{energy_load_history_provider.adapter_type}" + ) + + self._instance_cache[cache_key] = instance + return instance + except Exception as e: if self.logger: self.logger.error( - f"Home Forecast Provider ID {home_forecast_provider_id} not found or not a Home Forecast Provider." + f"Failed to initialize adapter '{energy_load_history_provider.name}' " + f"(Type: {energy_load_history_provider.adapter_type}): {e}" ) return None - return self._initialize_home_forecast_provider_adapter(home_forecast_provider) + + async def get_home_load_history_provider( + self, energy_load_history_provider_id: EntityId, device_id: EntityId + ) -> Optional[EnergyLoadHistoryProviderPort]: + """Get an energy load history provider adapter instance.""" + energy_load_history_provider = self.energy_load_history_provider_repo.get_by_id(energy_load_history_provider_id) + if not energy_load_history_provider: + if self.logger: + self.logger.error(f"Energy Load History Provider ID {energy_load_history_provider_id} not found.") + return None + return await self._initialize_energy_load_history_provider_adapter(energy_load_history_provider, device_id) async def get_mining_performance_tracker(self, tracker_id: EntityId) -> Optional[MiningPerformanceTrackerPort]: """Get a mining performance tracker adapter instance.""" diff --git a/edge_mining/application/services/configuration_service.py b/edge_mining/application/services/configuration_service.py index 34579d1..e738a24 100644 ---
a/edge_mining/application/services/configuration_service.py +++ b/edge_mining/application/services/configuration_service.py @@ -24,9 +24,21 @@ from edge_mining.domain.forecast.entities import ForecastProvider from edge_mining.domain.forecast.exceptions import ForecastProviderConfigurationError, ForecastProviderNotFoundError from edge_mining.domain.forecast.ports import ForecastProviderRepository -from edge_mining.domain.home_load.entities import HomeForecastProvider -from edge_mining.domain.home_load.exceptions import HomeForecastProviderNotFoundError -from edge_mining.domain.home_load.ports import HomeForecastProviderRepository +from edge_mining.domain.home_load.aggregate_roots import HomeLoadsProfile +from edge_mining.domain.home_load.common import EnergyLoadForecastProviderAdapter, EnergyLoadHistoryProviderAdapter +from edge_mining.domain.home_load.entities import EnergyLoadForecastProvider, EnergyLoadHistoryProvider, LoadDevice +from edge_mining.domain.home_load.exceptions import ( + EnergyLoadForecastProviderConfigurationError, + EnergyLoadForecastProviderNotFoundError, + EnergyLoadHistoryProviderConfigurationError, + EnergyLoadHistoryProviderNotFoundError, + HomeLoadsProfileNotFoundError, +) +from edge_mining.domain.home_load.ports import ( + EnergyLoadForecastProviderRepository, + EnergyLoadHistoryProviderRepository, + HomeLoadsProfileRepository, +) from edge_mining.domain.miner.aggregate_roots import Miner from edge_mining.domain.miner.common import MinerControllerAdapter, MinerFeatureType from edge_mining.domain.miner.entities import MinerController @@ -79,6 +91,10 @@ FORECAST_PROVIDER_CONFIG_TYPE_MAP, FORECAST_PROVIDER_TYPE_EXTERNAL_SERVICE_MAP, ) +from edge_mining.shared.adapter_maps.home_load import ( + ENERGY_LOAD_FORECAST_PROVIDER_EXTERNAL_SERVICE_MAP, + ENERGY_LOAD_HISTORY_PROVIDER_EXTERNAL_SERVICE_MAP, +) from edge_mining.shared.adapter_maps.miner import ( MINER_CONTROLLER_CONFIG_TYPE_MAP, MINER_CONTROLLER_TYPE_EXTERNAL_SERVICE_MAP, @@ -128,9 
+144,13 @@ def __init__( self.policy_repo: OptimizationPolicyRepository = persistence_settings.policy_repo self.optimization_unit_repo: EnergyOptimizationUnitRepository = persistence_settings.optimization_unit_repo self.forecast_provider_repo: ForecastProviderRepository = persistence_settings.forecast_provider_repo - self.home_forecast_provider_repo: HomeForecastProviderRepository = ( - persistence_settings.home_forecast_provider_repo + self.energy_load_forecast_provider_repo: EnergyLoadForecastProviderRepository = ( + persistence_settings.energy_load_forecast_provider_repo ) + self.energy_load_history_provider_repo: EnergyLoadHistoryProviderRepository = ( + persistence_settings.energy_load_history_provider_repo + ) + self.home_profile_repo: HomeLoadsProfileRepository = persistence_settings.home_profile_repo self.mining_performance_tracker_repo: MiningPerformanceTrackerRepository = ( persistence_settings.mining_performance_tracker_repo ) @@ -186,8 +206,11 @@ def get_entities_by_external_service(self, service_id: EntityId) -> ExternalServ miner_controllers: List[MinerController] = self.miner_controller_repo.get_by_external_service_id(service_id) energy_monitors: List[EnergyMonitor] = self.energy_monitor_repo.get_by_external_service_id(service_id) forecast_providers: List[ForecastProvider] = self.forecast_provider_repo.get_by_external_service_id(service_id) - home_forecast_providers: List[HomeForecastProvider] = ( - self.home_forecast_provider_repo.get_by_external_service_id(service_id) + energy_load_forecast_providers: List[EnergyLoadForecastProvider] = ( + self.energy_load_forecast_provider_repo.get_by_external_service_id(service_id) + ) + energy_load_history_providers: List[EnergyLoadHistoryProvider] = ( + self.energy_load_history_provider_repo.get_by_external_service_id(service_id) ) notifiers: List[Notifier] = self.notifier_repo.get_by_external_service_id(service_id) @@ -195,7 +218,8 @@ def get_entities_by_external_service(self, service_id: EntityId) -> 
ExternalServ miner_controllers=miner_controllers, energy_monitors=energy_monitors, forecast_providers=forecast_providers, - home_forecast_providers=home_forecast_providers, + energy_load_forecast_providers=energy_load_forecast_providers, + energy_load_history_providers=energy_load_history_providers, notifiers=notifiers, ) return external_service_linked_entities @@ -233,13 +257,22 @@ async def unlink_external_service(self, service_id: EntityId) -> None: self.forecast_provider_repo.update(forecast_provider) # Unlink from home forecast providers - for home_forecast_provider in external_service_linked_entities.home_forecast_providers: + for energy_load_forecast_provider in external_service_linked_entities.energy_load_forecast_providers: self.logger.debug( - f"Unlinking home forecast provider {home_forecast_provider.name} " - f"({home_forecast_provider.id}) from external service {service_id}" + f"Unlinking home forecast provider {energy_load_forecast_provider.name} " + f"({energy_load_forecast_provider.id}) from external service {service_id}" ) - home_forecast_provider.external_service_id = None - self.home_forecast_provider_repo.update(home_forecast_provider) + energy_load_forecast_provider.external_service_id = None + self.energy_load_forecast_provider_repo.update(energy_load_forecast_provider) + + # Unlink from home history providers + for energy_load_history_provider in external_service_linked_entities.energy_load_history_providers: + self.logger.debug( + f"Unlinking home history provider {energy_load_history_provider.name} " + f"({energy_load_history_provider.id}) from external service {service_id}" + ) + energy_load_history_provider.external_service_id = None + self.energy_load_history_provider_repo.update(energy_load_history_provider) # Unlink from notifiers for notifier in external_service_linked_entities.notifiers: @@ -879,8 +912,8 @@ async def create_optimization_unit( policy_id: Optional[EntityId] = None, target_miner_ids: Optional[List[EntityId]] = None, 
energy_source_id: Optional[EntityId] = None, - home_forecast_provider_id: Optional[EntityId] = None, performance_tracker_id: Optional[EntityId] = None, + home_loads_profile_id: Optional[EntityId] = None, notifier_ids: Optional[List[EntityId]] = None, ) -> Optional[EnergyOptimizationUnit]: """Create an optimization unit into the system.""" @@ -893,8 +926,8 @@ async def create_optimization_unit( policy_id=policy_id, target_miner_ids=target_miner_ids or [], energy_source_id=energy_source_id, - home_forecast_provider_id=home_forecast_provider_id, performance_tracker_id=performance_tracker_id, + home_loads_profile=home_loads_profile_id, notifier_ids=notifier_ids or [], ) @@ -922,7 +955,6 @@ def filter_optimization_units( filter_by_miners: Optional[List[EntityId]] = None, filter_by_energy_source: Optional[EntityId] = None, filter_by_policy: Optional[EntityId] = None, - filter_by_home_forecast_provider: Optional[EntityId] = None, filter_by_performance_tracker: Optional[EntityId] = None, filter_by_notifiers: Optional[List[EntityId]] = None, ) -> List[EnergyOptimizationUnit]: @@ -936,8 +968,6 @@ def filter_optimization_units( eous = [eou for eou in eous if eou.energy_source_id == filter_by_energy_source] if filter_by_policy is not None: eous = [eou for eou in eous if eou.policy_id == filter_by_policy] - if filter_by_home_forecast_provider is not None: - eous = [eou for eou in eous if eou.home_forecast_provider_id == filter_by_home_forecast_provider] if filter_by_performance_tracker is not None: eous = [eou for eou in eous if eou.performance_tracker_id == filter_by_performance_tracker] if filter_by_notifiers is not None: @@ -966,8 +996,8 @@ async def update_optimization_unit( policy_id: Optional[EntityId] = None, target_miner_ids: Optional[List[EntityId]] = None, energy_source_id: Optional[EntityId] = None, - home_forecast_provider_id: Optional[EntityId] = None, performance_tracker_id: Optional[EntityId] = None, + home_loads_profile_id: Optional[EntityId] = None, 
notifier_ids: Optional[List[EntityId]] = None, ) -> EnergyOptimizationUnit: """Update an optimization unit in the system.""" @@ -989,10 +1019,10 @@ async def update_optimization_unit( optimization_unit.target_miner_ids = target_miner_ids if energy_source_id is not None: optimization_unit.energy_source_id = energy_source_id - if home_forecast_provider_id is not None: - optimization_unit.home_forecast_provider_id = home_forecast_provider_id if performance_tracker_id is not None: optimization_unit.performance_tracker_id = performance_tracker_id + if home_loads_profile_id is not None: + optimization_unit.assign_home_loads_profile(home_loads_profile_id) if notifier_ids is not None: optimization_unit.notifier_ids = notifier_ids @@ -1139,36 +1169,35 @@ async def assign_energy_source_to_optimization_unit( return optimization_unit - async def assign_home_forecast_provider_to_optimization_unit( - self, unit_id: EntityId, home_forecast_provider_id: EntityId + async def assign_performance_tracker_to_optimization_unit( + self, unit_id: EntityId, performance_tracker_id: EntityId ) -> EnergyOptimizationUnit: - """Assign a home forecast provider to an optimization unit.""" - self.logger.info(f"Assigning home forecast provider {home_forecast_provider_id} to optimization unit {unit_id}") + """Assign a performance tracker to an optimization unit.""" + self.logger.info(f"Assigning performance tracker {performance_tracker_id} to optimization unit {unit_id}") optimization_unit = self.optimization_unit_repo.get_by_id(unit_id) if not optimization_unit: raise OptimizationUnitNotFoundError(f"Optimization Unit with ID {unit_id} not found.") - optimization_unit.home_forecast_provider_id = home_forecast_provider_id + optimization_unit.performance_tracker_id = performance_tracker_id self.check_optimization_unit(optimization_unit) self.optimization_unit_repo.update(optimization_unit) return optimization_unit - async def assign_performance_tracker_to_optimization_unit( - self, unit_id: EntityId, 
performance_tracker_id: EntityId + async def assign_home_loads_profile_to_optimization_unit( + self, unit_id: EntityId, home_loads_profile_id: Optional[EntityId] ) -> EnergyOptimizationUnit: - """Assign a performance tracker to an optimization unit.""" - self.logger.info(f"Assigning performance tracker {performance_tracker_id} to optimization unit {unit_id}") + """Assign a home loads profile to an optimization unit.""" + self.logger.info(f"Assigning home loads profile {home_loads_profile_id} to optimization unit {unit_id}") optimization_unit = self.optimization_unit_repo.get_by_id(unit_id) if not optimization_unit: raise OptimizationUnitNotFoundError(f"Optimization Unit with ID {unit_id} not found.") - optimization_unit.performance_tracker_id = performance_tracker_id - self.check_optimization_unit(optimization_unit) + optimization_unit.assign_home_loads_profile(home_loads_profile_id) self.optimization_unit_repo.update(optimization_unit) return optimization_unit @@ -1277,16 +1306,6 @@ def check_optimization_unit(self, optimization_unit: EnergyOptimizationUnit, str f"Optimization Unit {optimization_unit.id} must have an energy source assigned." ) - # Check if the home forecast provider is valid - if optimization_unit.home_forecast_provider_id: - home_forecast_provider = self.home_forecast_provider_repo.get_by_id( - optimization_unit.home_forecast_provider_id - ) - if not home_forecast_provider: - raise HomeForecastProviderNotFoundError( - f"Home Forecast Provider with ID {optimization_unit.home_forecast_provider_id} not found." 
- ) - # Check if the performance tracker is valid if optimization_unit.performance_tracker_id: performance_tracker = self.mining_performance_tracker_repo.get_by_id( @@ -1845,6 +1864,153 @@ def get_notifier_external_service_adapter( ) return NOTIFIER_TYPE_EXTERNAL_SERVICE_MAP.get(adapter_type, None) + # --- Home Loads Profile Management --- + def add_home_loads_profile(self, name: str) -> HomeLoadsProfile: + """Create and persist a new home loads profile.""" + profile = HomeLoadsProfile(name=name) + self.home_profile_repo.add(profile) + self.logger.info(f"Added home loads profile '{profile.name}' ({profile.id}).") + return profile + + def get_home_loads_profile(self, profile_id: EntityId) -> Optional[HomeLoadsProfile]: + """Get a home loads profile by ID.""" + return self.home_profile_repo.get_by_id(profile_id) + + def list_home_loads_profiles(self) -> List[HomeLoadsProfile]: + """List all home loads profiles.""" + return self.home_profile_repo.get_all() + + def update_home_loads_profile(self, profile_id: EntityId, name: str) -> HomeLoadsProfile: + """Rename an existing home loads profile.""" + profile = self.home_profile_repo.get_by_id(profile_id) + if not profile: + raise HomeLoadsProfileNotFoundError(f"Home Loads Profile with ID {profile_id} not found.") + profile.name = name + self.home_profile_repo.update(profile) + return profile + + def remove_home_loads_profile(self, profile_id: EntityId) -> HomeLoadsProfile: + """Remove a home loads profile by ID.""" + profile = self.home_profile_repo.get_by_id(profile_id) + if not profile: + raise HomeLoadsProfileNotFoundError(f"Home Loads Profile with ID {profile_id} not found.") + self.home_profile_repo.remove(profile_id) + return profile + + def add_load_device_to_profile(self, profile_id: EntityId, load_device: LoadDevice) -> LoadDevice: + """Append a load device to a profile (raises on duplicate device name).""" + profile = self.home_profile_repo.get_by_id(profile_id) + if not profile: + raise 
HomeLoadsProfileNotFoundError(f"Home Loads Profile with ID {profile_id} not found.") + profile.add_device(load_device) + self.home_profile_repo.update(profile) + return load_device + + def remove_load_device_from_profile(self, profile_id: EntityId, device_id: EntityId) -> LoadDevice: + """Remove a load device from a profile.""" + profile = self.home_profile_repo.get_by_id(profile_id) + if not profile: + raise HomeLoadsProfileNotFoundError(f"Home Loads Profile with ID {profile_id} not found.") + removed = profile.remove_device(device_id) + self.home_profile_repo.update(profile) + return removed + + # --- Energy Load Forecast Provider Management --- + def add_energy_load_forecast_provider(self, provider: EnergyLoadForecastProvider) -> EnergyLoadForecastProvider: + """Add a new energy load forecast provider.""" + self.energy_load_forecast_provider_repo.add(provider) + self.logger.info(f"Added energy load forecast provider '{provider.name}' ({provider.id}).") + return provider + + def get_energy_load_forecast_provider(self, provider_id: EntityId) -> Optional[EnergyLoadForecastProvider]: + """Get an energy load forecast provider by ID.""" + return self.energy_load_forecast_provider_repo.get_by_id(provider_id) + + def list_energy_load_forecast_providers(self) -> List[EnergyLoadForecastProvider]: + """List all energy load forecast providers.""" + return self.energy_load_forecast_provider_repo.get_all() + + def update_energy_load_forecast_provider(self, provider: EnergyLoadForecastProvider) -> EnergyLoadForecastProvider: + """Update an existing energy load forecast provider.""" + existing = self.energy_load_forecast_provider_repo.get_by_id(provider.id) + if not existing: + raise EnergyLoadForecastProviderNotFoundError( + f"Energy Load Forecast Provider with ID {provider.id} not found." 
+ ) + self.energy_load_forecast_provider_repo.update(provider) + self.logger.info(f"Updated energy load forecast provider '{provider.name}' ({provider.id}).") + return provider + + def remove_energy_load_forecast_provider(self, provider_id: EntityId) -> EnergyLoadForecastProvider: + """Remove an energy load forecast provider.""" + provider = self.energy_load_forecast_provider_repo.get_by_id(provider_id) + if not provider: + raise EnergyLoadForecastProviderNotFoundError( + f"Energy Load Forecast Provider with ID {provider_id} not found." + ) + self.energy_load_forecast_provider_repo.remove(provider_id) + self.logger.info(f"Removed energy load forecast provider '{provider.name}' ({provider.id}).") + return provider + + # --- Energy Load History Provider Management --- + def add_energy_load_history_provider(self, provider: EnergyLoadHistoryProvider) -> EnergyLoadHistoryProvider: + """Add a new energy load history provider.""" + self.energy_load_history_provider_repo.add(provider) + self.logger.info(f"Added energy load history provider '{provider.name}' ({provider.id}).") + return provider + + def get_energy_load_history_provider(self, provider_id: EntityId) -> Optional[EnergyLoadHistoryProvider]: + """Get an energy load history provider by ID.""" + return self.energy_load_history_provider_repo.get_by_id(provider_id) + + def list_energy_load_history_providers(self) -> List[EnergyLoadHistoryProvider]: + """List all energy load history providers.""" + return self.energy_load_history_provider_repo.get_all() + + def update_energy_load_history_provider(self, provider: EnergyLoadHistoryProvider) -> EnergyLoadHistoryProvider: + """Update an existing energy load history provider.""" + existing = self.energy_load_history_provider_repo.get_by_id(provider.id) + if not existing: + raise EnergyLoadHistoryProviderNotFoundError( + f"Energy Load History Provider with ID {provider.id} not found." 
+ ) + self.energy_load_history_provider_repo.update(provider) + self.logger.info(f"Updated energy load history provider '{provider.name}' ({provider.id}).") + return provider + + def remove_energy_load_history_provider(self, provider_id: EntityId) -> EnergyLoadHistoryProvider: + """Remove an energy load history provider.""" + provider = self.energy_load_history_provider_repo.get_by_id(provider_id) + if not provider: + raise EnergyLoadHistoryProviderNotFoundError( + f"Energy Load History Provider with ID {provider_id} not found." + ) + self.energy_load_history_provider_repo.remove(provider_id) + self.logger.info(f"Removed energy load history provider '{provider.name}' ({provider.id}).") + return provider + + def get_energy_load_forecast_provider_external_service_adapter( + self, adapter_type: EnergyLoadForecastProviderAdapter + ) -> Optional[ExternalServiceAdapter]: + """Get the external service adapter type for a specific energy load forecast provider adapter type.""" + self.logger.debug(f"Getting external service adapter for energy load forecast provider adapter {adapter_type}") + if adapter_type not in ENERGY_LOAD_FORECAST_PROVIDER_EXTERNAL_SERVICE_MAP: + raise EnergyLoadForecastProviderConfigurationError( + f"Adapter type {adapter_type} is not supported for energy load forecast provider configuration." 
+ ) + return ENERGY_LOAD_FORECAST_PROVIDER_EXTERNAL_SERVICE_MAP.get(adapter_type, None) + + def get_energy_load_history_provider_external_service_adapter( + self, adapter_type: EnergyLoadHistoryProviderAdapter + ) -> Optional[ExternalServiceAdapter]: + """Get the external service adapter type for a specific energy load history provider adapter type.""" + self.logger.debug(f"Getting external service adapter for energy load history provider adapter {adapter_type}") + if adapter_type not in ENERGY_LOAD_HISTORY_PROVIDER_EXTERNAL_SERVICE_MAP: + raise EnergyLoadHistoryProviderConfigurationError( + f"Adapter type {adapter_type} is not supported for energy load history provider configuration." + ) + return ENERGY_LOAD_HISTORY_PROVIDER_EXTERNAL_SERVICE_MAP.get(adapter_type, None) + # --- Mining Performance Tracker Management --- async def add_mining_performance_tracker( self, @@ -1999,8 +2165,7 @@ def get_mining_performance_tracker_external_service_adapter( self.logger.debug(f"Getting external service adapter for mining performance tracker adapter {adapter_type}") if adapter_type not in MINING_PERFORMANCE_TRACKER_TYPE_EXTERNAL_SERVICE_MAP: raise MiningPerformanceTrackerConfigurationError( - f"Adapter type {adapter_type} is not supported " - "for mining performance tracker external service mapping." + f"Adapter type {adapter_type} is not supported for mining performance tracker external service mapping." 
) return MINING_PERFORMANCE_TRACKER_TYPE_EXTERNAL_SERVICE_MAP.get(adapter_type, None) diff --git a/edge_mining/application/services/home_load_history_service.py b/edge_mining/application/services/home_load_history_service.py new file mode 100644 index 0000000..50589f1 --- /dev/null +++ b/edge_mining/application/services/home_load_history_service.py @@ -0,0 +1,179 @@ +"""Service for collecting and purging home load consumption history.""" + +from datetime import datetime, timedelta, timezone +from typing import List, Optional + +from edge_mining.application.interfaces import ( + AdapterServiceInterface, + EventBusInterface, + HomeLoadHistoryServiceInterface, +) +from edge_mining.domain.common import EntityId, Timestamp +from edge_mining.domain.home_load.events import ( + LoadConsumptionHistoryCollectedEvent, + LoadConsumptionHistoryPurgedEvent, +) +from edge_mining.domain.home_load.ports import ( + EnergyLoadHistoryRepository, + HomeLoadsProfileRepository, +) +from edge_mining.domain.home_load.value_objects import HomeLoadPowerPoint +from edge_mining.shared.logging.port import LoggerPort + + +class HomeLoadHistoryService(HomeLoadHistoryServiceInterface): + """Collects power-point data from history providers and manages retention.""" + + def __init__( + self, + home_loads_repo: HomeLoadsProfileRepository, + home_load_history_repo: EnergyLoadHistoryRepository, + adapter_service: AdapterServiceInterface, + event_bus: Optional[EventBusInterface] = None, + logger: Optional[LoggerPort] = None, + ): + self.home_loads_repo = home_loads_repo + self.home_load_history_repo = home_load_history_repo + self.adapter_service = adapter_service + self._event_bus = event_bus + self.logger = logger + + async def collect_all(self, lookback_hours: int = 24) -> None: + """Collect power points from all history providers for all enabled devices. 
+ + For each enabled LoadDevice that has an energy_load_history_provider_id, + fetches new power points since the last known timestamp (delta ingestion) + and persists them in the history repository. + """ + profiles = self.home_loads_repo.get_all() + if not profiles: + if self.logger: + self.logger.debug("No home load profiles found. Skipping history collection.") + return + + for profile in profiles: + for device in profile.devices: + if not device.enabled: + continue + if not device.energy_load_history_provider_id: + continue + await self._collect_for_device( + device_id=device.id, + device_name=device.name, + provider_id=device.energy_load_history_provider_id, + lookback_hours=lookback_hours, + ) + + async def _collect_for_device( + self, + device_id: EntityId, + device_name: str, + provider_id: EntityId, + lookback_hours: int = 24, + ) -> None: + """Collect power points for a single device from its history provider.""" + history_provider = await self.adapter_service.get_home_load_history_provider(provider_id, device_id) + if not history_provider: + if self.logger: + self.logger.warning(f"History provider {provider_id} not found for device '{device_name}'. 
Skipping.") + return + + now = Timestamp(datetime.now(timezone.utc)) + last_ts = self.home_load_history_repo.get_latest_timestamp(device_id) + if last_ts is not None: + start = last_ts + else: + start = Timestamp(now - timedelta(hours=lookback_hours)) + + try: + power_points = await history_provider.get_power_points(start, now) + except Exception as e: + if self.logger: + self.logger.error( + f"Error fetching power points for device '{device_name}' from provider {provider_id}: {e}" + ) + return + + if not power_points: + return + + self.home_load_history_repo.add_power_points(device_id, power_points) + if self.logger: + self.logger.debug(f"Collected {len(power_points)} power points for device '{device_name}'.") + + if self._event_bus: + await self._event_bus.publish( + LoadConsumptionHistoryCollectedEvent( + device_id=device_id, + device_name=device_name, + points_collected=len(power_points), + ) + ) + + async def purge_all(self, retention_days: int = 90) -> None: + """Purge power points older than retention_days for all devices. + + Iterates all profiles and their devices, purging historical data that + exceeds the retention window. + """ + cutoff = Timestamp(datetime.now(timezone.utc) - timedelta(days=retention_days)) + profiles = self.home_loads_repo.get_all() + if not profiles: + return + + for profile in profiles: + for device in profile.devices: + try: + purged = self.home_load_history_repo.purge_before(device.id, cutoff) + except Exception as e: + if self.logger: + self.logger.error(f"Error purging history for device '{device.name}': {e}") + continue + + if purged > 0: + if self.logger: + self.logger.debug( + f"Purged {purged} power points for device '{device.name}' " + f"(older than {retention_days} days)."
+                        )
+                    if self._event_bus:
+                        await self._event_bus.publish(
+                            LoadConsumptionHistoryPurgedEvent(
+                                device_id=device.id,
+                                device_name=device.name,
+                                points_purged=purged,
+                            )
+                        )
+
+    def get_device_history(self, device_id: EntityId, start: Timestamp, end: Timestamp) -> List[HomeLoadPowerPoint]:
+        """Retrieve stored power points for a device in a time window."""
+        return self.home_load_history_repo.get_power_points(device_id, start, end)
+
+    def clear_device_history(self, device_id: EntityId) -> int:
+        """Delete all stored power points for a device."""
+        removed = self.home_load_history_repo.clear_device_history(device_id)
+        if self.logger:
+            self.logger.info(f"Cleared {removed} power points for device {device_id}.")
+        return removed
+
+    async def collect_devices(self, device_ids: List[EntityId], lookback_hours: int = 24) -> None:
+        """Collect power points for the specified devices only."""
+        profiles = self.home_loads_repo.get_all()
+        if not profiles:
+            return
+
+        target_ids = set(device_ids)
+        for profile in profiles:
+            for device in profile.devices:
+                if device.id not in target_ids:
+                    continue
+                if not device.energy_load_history_provider_id:
+                    if self.logger:
+                        self.logger.warning(f"Device '{device.name}' has no history provider configured. Skipping.")
+                    continue
+                await self._collect_for_device(
+                    device_id=device.id,
+                    device_name=device.name,
+                    provider_id=device.energy_load_history_provider_id,
+                    lookback_hours=lookback_hours,
+                )
diff --git a/edge_mining/application/services/load_forecast_training_service.py b/edge_mining/application/services/load_forecast_training_service.py
new file mode 100644
index 0000000..f3a8058
--- /dev/null
+++ b/edge_mining/application/services/load_forecast_training_service.py
@@ -0,0 +1,405 @@
+"""Service for training ML forecast models on collected home load history."""
+
+import pickle
+from datetime import datetime, timedelta, timezone
+from typing import List, Optional
+
+from edge_mining.adapters.domain.home_load.forecast_providers.features import (
+    fill_missing_hours,
+    intervals_to_hourly_series,
+    prepare_supervised_dataset,
+)
+from edge_mining.adapters.domain.home_load.history_providers.helpers import group_power_points_into_intervals
+from edge_mining.application.interfaces import LoadForecastTrainingServiceInterface
+from edge_mining.domain.common import EntityId, Timestamp
+from edge_mining.domain.home_load.common import EnergyLoadForecastProviderAdapter
+from edge_mining.domain.home_load.entities import LoadConsumptionModel
+from edge_mining.domain.home_load.ports import (
+    EnergyLoadHistoryRepository,
+    HomeLoadsProfileRepository,
+    LoadConsumptionModelRepository,
+)
+from edge_mining.domain.home_load.value_objects import LoadEnergyConsumption
+from edge_mining.shared.logging.port import LoggerPort
+
+
+class LoadForecastModelTrainingService(LoadForecastTrainingServiceInterface):
+    """Trains ML models (Statsmodels, XGBoost, skforecast) on historical home load data.
+
+    Designed to be run nightly via the scheduler. For each enabled device
+    that has enough history, trains Holt-Winters, XGBoost, and skforecast
+    candidate models, evaluates them against a holdout set, and promotes
+    the best one to active.
+    """
+
+    def __init__(
+        self,
+        home_loads_repo: HomeLoadsProfileRepository,
+        history_repo: EnergyLoadHistoryRepository,
+        model_repo: LoadConsumptionModelRepository,
+        logger: Optional[LoggerPort] = None,
+    ):
+        self._home_loads_repo = home_loads_repo
+        self._history_repo = history_repo
+        self._model_repo = model_repo
+        self._logger = logger
+
+    async def train_all(self, weeks_lookback: int = 8) -> None:
+        """Train models for every device that has sufficient history."""
+        profiles = self._home_loads_repo.get_all()
+        if not profiles:
+            if self._logger:
+                self._logger.debug("No home load profiles found. Skipping training.")
+            return
+
+        for profile in profiles:
+            for device in profile.devices:
+                if not device.enabled:
+                    continue
+                try:
+                    await self._train_for_device(device.id, device.name, weeks_lookback)
+                except Exception as exc:
+                    if self._logger:
+                        self._logger.error(f"Training failed for device '{device.name}': {exc}")
+
+    async def train_device(self, device_id: EntityId, weeks_lookback: int = 8) -> None:
+        """Train models for a single device identified by device_id."""
+        profiles = self._home_loads_repo.get_all()
+        device_name: Optional[str] = None
+        for profile in profiles:
+            for device in profile.devices:
+                if device.id == device_id:
+                    device_name = device.name
+                    break
+            if device_name is not None:
+                break
+
+        if device_name is None:
+            if self._logger:
+                self._logger.warning(f"Device {device_id} not found in any profile. Skipping training.")
+            return
+
+        await self._train_for_device(device_id, device_name, weeks_lookback)
+
+    def get_models(self, device_id: Optional[EntityId] = None) -> List[LoadConsumptionModel]:
+        """Retrieve trained models, optionally filtered by device."""
+        return self._model_repo.get_all(device_id)
+
+    def delete_model(self, model_id: EntityId) -> None:
+        """Delete a trained model by ID."""
+        self._model_repo.remove(model_id)
+
+    async def _train_for_device(
+        self,
+        device_id: EntityId,
+        device_name: str,
+        weeks_lookback: int,
+    ) -> None:
+        """Train all candidate models (HW, XGBoost, skforecast) for one device and promote the best one."""
+        now = Timestamp(datetime.now(timezone.utc))
+        lookback_start = Timestamp(now - timedelta(weeks=weeks_lookback))
+
+        power_points = self._history_repo.get_power_points(device_id, lookback_start, now)
+        if len(power_points) < 48 * 2:  # require at least 96 points of history for train+holdout
+            if self._logger:
+                self._logger.debug(
+                    f"Insufficient history for device '{device_name}' ({len(power_points)} points). Skipping training."
+                )
+            return
+
+        # Build LoadEnergyConsumption from power points
+        intervals = group_power_points_into_intervals(power_points)
+        consumption = LoadEnergyConsumption(timestamp=now, intervals=intervals)
+
+        # Split: last 24h as holdout
+        holdout_start = Timestamp(now - timedelta(hours=24))
+        train_consumption = consumption.in_window(lookback_start, holdout_start)
+        holdout_consumption = consumption.in_window(holdout_start, now)
+
+        if len(train_consumption.intervals) < 48 or len(holdout_consumption.intervals) < 12:
+            if self._logger:
+                self._logger.debug(f"Not enough data after split for device '{device_name}'. Skipping.")
+            return
+
+        hw_model = self._train_hw(train_consumption, holdout_consumption, device_id, device_name)
+        xgb_model = self._train_xgb(train_consumption, holdout_consumption, device_id, device_name)
+        skf_model = self._train_skforecast(train_consumption, holdout_consumption, device_id, device_name)
+
+        # Promote the best model
+        candidates = [m for m in [hw_model, xgb_model, skf_model] if m is not None and m.mae is not None]
+        if not candidates:
+            if self._logger:
+                self._logger.warning(f"No model trained successfully for device '{device_name}'.")
+            return
+
+        best = min(candidates, key=lambda m: m.mae)  # type: ignore[arg-type]
+        best.is_active = True
+
+        # Deactivate previous active models for this device
+        for adapter_type in [
+            EnergyLoadForecastProviderAdapter.STATSMODELS,
+            EnergyLoadForecastProviderAdapter.XGBOOST,
+            EnergyLoadForecastProviderAdapter.SKFORECAST,
+        ]:
+            old = self._model_repo.get_active_model(adapter_type, device_id)
+            if old is not None:
+                old.is_active = False
+                self._model_repo.update(old)
+
+        # Persist all trained models
+        for model in candidates:
+            self._model_repo.add(model)
+
+        if self._logger:
+            self._logger.info(
+                f"Trained models for device '{device_name}': best={best.adapter_type.value} MAE={best.mae:.2f}"
+            )
+
+    def _train_hw(
+        self,
+        train: LoadEnergyConsumption,
+        holdout: LoadEnergyConsumption,
+        device_id: EntityId,
+        device_name: str,
+    ) -> Optional[LoadConsumptionModel]:
+        """Train Holt-Winters and evaluate on holdout."""
+        try:
+            from statsmodels.tsa.holtwinters import ExponentialSmoothing
+        except ImportError:
+            return None
+
+        series = intervals_to_hourly_series(train)
+        series = fill_missing_hours(series)
+        powers = [p for _, p in series]
+
+        seasonal_periods = 24
+        if len(powers) < seasonal_periods * 2:
+            return None
+
+        try:
+            model = ExponentialSmoothing(powers, trend="add", seasonal="add", seasonal_periods=seasonal_periods)
+            fitted = model.fit(optimized=True)
+            model_bytes = pickle.dumps(fitted)
+
+            # 
Evaluate on holdout
+            holdout_series = intervals_to_hourly_series(holdout)
+            holdout_series = fill_missing_hours(holdout_series)
+            holdout_powers = [p for _, p in holdout_series]
+
+            n_eval = min(len(holdout_powers), 24)
+            if n_eval == 0:
+                return None
+
+            forecast = fitted.forecast(n_eval)
+            mae = sum(abs(float(forecast[i]) - holdout_powers[i]) for i in range(n_eval)) / n_eval
+            rmse = (sum((float(forecast[i]) - holdout_powers[i]) ** 2 for i in range(n_eval)) / n_eval) ** 0.5
+
+            return LoadConsumptionModel(
+                device_id=device_id,
+                adapter_type=EnergyLoadForecastProviderAdapter.STATSMODELS,
+                trained_at=datetime.now(timezone.utc),
+                mae=mae,
+                rmse=rmse,
+                samples_used=len(powers),
+                is_active=False,
+                model_bytes=model_bytes,
+            )
+        except Exception as exc:
+            if self._logger:
+                self._logger.warning(f"Holt-Winters training failed for '{device_name}': {exc}")
+            return None
+
+    def _train_xgb(
+        self,
+        train: LoadEnergyConsumption,
+        holdout: LoadEnergyConsumption,
+        device_id: EntityId,
+        device_name: str,
+    ) -> Optional[LoadConsumptionModel]:
+        """Train XGBoost and evaluate on holdout."""
+        try:
+            import xgboost as xgb
+        except ImportError:
+            return None
+
+        hours_ahead = 3
+        X_train, y_train = prepare_supervised_dataset(train, hours_ahead=hours_ahead)
+        if len(X_train) < 48:
+            return None
+
+        try:
+            model = xgb.XGBRegressor(
+                n_estimators=100, max_depth=6, learning_rate=0.1, objective="reg:squarederror", verbosity=0
+            )
+            model.fit(X_train, y_train)
+            model_bytes = pickle.dumps(model)
+
+            # Evaluate on holdout
+            X_holdout, y_holdout = prepare_supervised_dataset(holdout, hours_ahead=hours_ahead)
+            if len(X_holdout) < 3:
+                # If holdout has insufficient supervised pairs, use raw MAE
+                holdout_series = intervals_to_hourly_series(holdout)
+                holdout_series = fill_missing_hours(holdout_series)
+                holdout_powers = [p for _, p in holdout_series]
+                if not holdout_powers:
+                    return None
+                # Predict on holdout features (from training data end)
+                X_eval, y_eval = prepare_supervised_dataset(
+                    LoadEnergyConsumption(
+                        timestamp=holdout.timestamp,
+                        intervals=list(train.intervals) + list(holdout.intervals),
+                    ),
+                    hours_ahead=hours_ahead,
+                )
+                # Use last portion as holdout
+                n_eval = min(len(holdout_powers), len(X_eval))
+                if n_eval == 0:
+                    return None
+                X_eval = X_eval[-n_eval:]
+                y_eval = y_eval[-n_eval:]
+            else:
+                X_eval, y_eval = X_holdout, y_holdout
+                n_eval = len(y_eval)
+
+            predictions = model.predict(X_eval)
+            mae = sum(abs(float(predictions[i]) - y_eval[i]) for i in range(n_eval)) / n_eval
+            rmse = (sum((float(predictions[i]) - y_eval[i]) ** 2 for i in range(n_eval)) / n_eval) ** 0.5
+
+            return LoadConsumptionModel(
+                device_id=device_id,
+                adapter_type=EnergyLoadForecastProviderAdapter.XGBOOST,
+                trained_at=datetime.now(timezone.utc),
+                mae=mae,
+                rmse=rmse,
+                samples_used=len(X_train),
+                is_active=False,
+                model_bytes=model_bytes,
+            )
+        except Exception as exc:
+            if self._logger:
+                self._logger.warning(f"XGBoost training failed for '{device_name}': {exc}")
+            return None
+
+    def _train_skforecast(
+        self,
+        train: LoadEnergyConsumption,
+        holdout: LoadEnergyConsumption,
+        device_id: EntityId,
+        device_name: str,
+        sklearn_model: str = "RandomForestRegressor",
+        num_lags: int = 72,
+        perform_tuning: bool = True,
+        tuning_trials: int = 20,
+    ) -> Optional[LoadConsumptionModel]:
+        """Train skforecast ForecasterRecursive and evaluate on holdout.
+
+        If ``perform_tuning`` is True and Optuna is available, Bayesian
+        hyperparameter tuning is run after the initial fit to find the
+        best combination of model hyperparameters and lag count.
+        """
+        try:
+            import pandas as pd_
+            from skforecast.recursive import ForecasterRecursive as FR
+
+            from edge_mining.adapters.domain.home_load.forecast_providers.skforecast_provider import (
+                SkforecastForecastProvider,
+                _resolve_sklearn_model,
+            )
+        except ImportError:
+            return None
+
+        series = intervals_to_hourly_series(train)
+        series = fill_missing_hours(series)
+        powers = [p for _, p in series]
+
+        if len(powers) < num_lags + 24:
+            return None
+
+        try:
+            y = pd_.Series(powers, name="power")
+            tuning_params: Optional[dict] = None
+
+            # --- Optuna tuning (optional) ---
+            if perform_tuning and len(powers) >= num_lags + 48 + 24:
+                try:
+                    best_params, tuned_forecaster = SkforecastForecastProvider.tune(
+                        y_series=y,
+                        sklearn_model_name=sklearn_model,
+                        num_lags=num_lags,
+                        steps=24,
+                        n_trials=tuning_trials,
+                    )
+                    tuning_params = best_params
+                    model_bytes = pickle.dumps(tuned_forecaster)
+                    forecaster = tuned_forecaster
+
+                    if self._logger:
+                        self._logger.debug(f"Optuna tuning completed for '{device_name}': {best_params}")
+                except Exception as tune_exc:
+                    if self._logger:
+                        self._logger.warning(f"Optuna tuning failed for '{device_name}', using base fit: {tune_exc}")
+                    # Fallback to base fit
+                    regressor = _resolve_sklearn_model(sklearn_model)
+                    forecaster = FR(estimator=regressor, lags=num_lags)
+                    forecaster.fit(y=y)
+                    model_bytes = pickle.dumps(forecaster)
+            else:
+                # Base fit without tuning
+                regressor = _resolve_sklearn_model(sklearn_model)
+                forecaster = FR(estimator=regressor, lags=num_lags)
+                forecaster.fit(y=y)
+                model_bytes = pickle.dumps(forecaster)
+
+            # Evaluate on holdout
+            holdout_series = intervals_to_hourly_series(holdout)
+            holdout_series = fill_missing_hours(holdout_series)
+            holdout_powers = [p for _, p in holdout_series]
+
+            n_eval = min(len(holdout_powers), 24)
+            if n_eval == 0:
+                return None
+
+            predictions = forecaster.predict(steps=n_eval)
+            pred_list = predictions.tolist()
+            mae = sum(abs(float(pred_list[i]) - holdout_powers[i]) for i in range(n_eval)) / n_eval
+            rmse = (sum((float(pred_list[i]) - holdout_powers[i]) ** 2 for i in range(n_eval)) / n_eval) ** 0.5
+
+            # --- Rolling-window backtesting ---
+            backtest_mae: Optional[float] = None
+            backtest_rmse: Optional[float] = None
+            backtest_folds: int = 0
+            try:
+                bt_result = SkforecastForecastProvider.backtest(
+                    forecaster=forecaster,
+                    y_series=y,
+                    steps=24,
+                    folds=3,
+                )
+                backtest_mae = bt_result.get("backtest_mae")
+                backtest_rmse = bt_result.get("backtest_rmse")
+                backtest_folds = bt_result.get("backtest_folds", 0)
+                if self._logger:
+                    self._logger.debug(
+                        f"Backtesting for '{device_name}': MAE={backtest_mae}, RMSE={backtest_rmse}, folds={backtest_folds}"
+                    )
+            except Exception as bt_exc:
+                if self._logger:
+                    self._logger.warning(f"Backtesting failed for '{device_name}': {bt_exc}")
+
+            return LoadConsumptionModel(
+                device_id=device_id,
+                adapter_type=EnergyLoadForecastProviderAdapter.SKFORECAST,
+                trained_at=datetime.now(timezone.utc),
+                mae=mae,
+                rmse=rmse,
+                samples_used=len(powers),
+                is_active=False,
+                model_bytes=model_bytes,
+                tuning_params=tuning_params,
+                backtest_mae=backtest_mae,
+                backtest_rmse=backtest_rmse,
+                backtest_folds=backtest_folds,
+            )
+        except Exception as exc:
+            if self._logger:
+                self._logger.warning(f"Skforecast training failed for '{device_name}': {exc}")
+            return None
diff --git a/edge_mining/application/services/optimization_service.py b/edge_mining/application/services/optimization_service.py
index d4232ff..fd1ff76 100644
--- a/edge_mining/application/services/optimization_service.py
+++ b/edge_mining/application/services/optimization_service.py
@@ -8,7 +8,8 @@
 """
 
 import asyncio
-from typing import List, Optional
+from datetime import datetime, timedelta
+from typing import Dict, List, Optional
 
 from edge_mining.application.interfaces import (
     AdapterServiceInterface,
@@ -16,15 +17,26 @@
     OptimizationServiceInterface,
     SunFactoryInterface,
 )
-from edge_mining.domain.common 
import EntityId, Timestamp, WattHours from edge_mining.domain.energy.entities import EnergySource from edge_mining.domain.energy.events import EnergyStateSnapshotUpdatedEvent from edge_mining.domain.energy.ports import EnergyMonitorPort, EnergySourceRepository from edge_mining.domain.energy.value_objects import EnergyStateSnapshot from edge_mining.domain.forecast.aggregate_root import Forecast from edge_mining.domain.forecast.ports import ForecastProviderPort -from edge_mining.domain.home_load.ports import HomeForecastProviderPort -from edge_mining.domain.home_load.value_objects import ConsumptionForecast +from edge_mining.domain.home_load.aggregate_roots import HomeLoadsProfile +from edge_mining.domain.home_load.entities import LoadDevice +from edge_mining.domain.home_load.ports import ( + EnergyLoadForecastProviderPort, + EnergyLoadHistoryProviderPort, + HomeLoadsProfileRepository, +) +from edge_mining.domain.home_load.value_objects import ( + HomeLoadEnergyInterval, + HomeLoadsConsumption, + LoadDeviceConsumption, + LoadEnergyConsumption, +) from edge_mining.domain.miner.aggregate_roots import Miner from edge_mining.domain.miner.common import MinerFeatureType, MinerStatus from edge_mining.domain.miner.events import MinerStateChangedEvent @@ -69,10 +81,13 @@ def __init__( energy_source_repo: EnergySourceRepository, policy_repo: OptimizationPolicyRepository, miner_repo: MinerRepository, + home_loads_repo: HomeLoadsProfileRepository, adapter_service: AdapterServiceInterface, sun_factory: SunFactoryInterface, event_bus: Optional[EventBusInterface] = None, logger: Optional[LoggerPort] = None, + forecast_mix_alpha: float = 0.5, + forecast_mix_beta: float = 0.5, ): # Domains @@ -81,6 +96,7 @@ def __init__( self.energy_source_repo = energy_source_repo self.policy_repo = policy_repo self.miner_repo = miner_repo + self.home_loads_repo = home_loads_repo # Infrastructure self.sun_factory = sun_factory @@ -88,6 +104,135 @@ def __init__( self._event_bus = event_bus 
self.logger = logger
+        # Forecast blending (α/β mix of forecast with last real measurement)
+        self.forecast_mix_alpha = forecast_mix_alpha
+        self.forecast_mix_beta = forecast_mix_beta
+
+    @staticmethod
+    def _sum_consumptions(consumptions: List[LoadEnergyConsumption]) -> LoadEnergyConsumption:
+        """Sum a list of LoadEnergyConsumption by matching (start, end) intervals."""
+        now_ts = Timestamp(datetime.now())
+        if not consumptions:
+            return LoadEnergyConsumption(timestamp=now_ts, intervals=[])
+
+        buckets: Dict[tuple, List[HomeLoadEnergyInterval]] = {}
+        for consumption in consumptions:
+            for interval in consumption.intervals:
+                buckets.setdefault((interval.start, interval.end), []).append(interval)
+
+        merged: List[HomeLoadEnergyInterval] = []
+        for (start, end), intervals in sorted(buckets.items(), key=lambda kv: kv[0][0]):
+            # Keep a legitimate 0 Wh total distinct from "no energy data" for this bucket.
+            has_energy = any(i.energy is not None for i in intervals)
+            total_energy = WattHours(sum(float(i.energy) for i in intervals if i.energy is not None))
+            power_points = [p for i in intervals for p in i.power_points]
+            merged.append(
+                HomeLoadEnergyInterval(
+                    start=start,
+                    end=end,
+                    energy=total_energy if has_energy else None,
+                    power_points=power_points,
+                )
+            )
+
+        return LoadEnergyConsumption(timestamp=now_ts, intervals=merged)
+
+    async def _build_home_loads_consumption(
+        self,
+        home_loads_profile: Optional[HomeLoadsProfile],
+        forecast_providers: Dict[EntityId, EnergyLoadForecastProviderPort],
+        history_providers: Dict[EntityId, EnergyLoadHistoryProviderPort],
+        unit_name: str,
+    ) -> Optional[HomeLoadsConsumption]:
+        """Assemble per-device history and forecast plus the household totals.
+
+        For each device, history is fetched from its history provider (if any)
+        over a 24-hour look-back window. The forecast is obtained by calling each
+        device's forecast provider with the device history.
+ """ + if home_loads_profile is None: + return None + + now = Timestamp(datetime.now()) + window_start = Timestamp(now - timedelta(hours=24)) + empty_consumption = LoadEnergyConsumption(timestamp=now, intervals=[]) + + per_device: List[LoadDeviceConsumption] = [] + for device in home_loads_profile.devices: + # --- History --- + device_history = empty_consumption + history_provider = history_providers.get(device.id) + if history_provider is not None: + try: + intervals = await history_provider.get_history(window_start, now) + if intervals: + device_history = LoadEnergyConsumption(timestamp=now, intervals=intervals) + elif self.logger: + self.logger.debug(f"[HomeLoad] History provider for '{device.name}' returned empty intervals") + except Exception as e: + if self.logger: + self.logger.warning( + f"Error getting load history for device '{device.name}' " + f"in optimization unit '{unit_name}': {e}" + ) + elif self.logger: + self.logger.debug( + f"[HomeLoad] No history provider for device '{device.name}' " + f"(history_provider_id={device.energy_load_history_provider_id})" + ) + + # --- Forecast --- + device_forecast = empty_consumption + forecast_provider = forecast_providers.get(device.id) + if forecast_provider is not None: + try: + result = forecast_provider.get_consumption_forecast(device_history) + if result is not None: + device_forecast = result + elif self.logger: + self.logger.debug(f"[HomeLoad] Forecast provider for '{device.name}' returned None") + except Exception as e: + if self.logger: + self.logger.warning( + f"Error getting load forecast for device '{device.name}' " + f"in optimization unit '{unit_name}': {e}" + ) + elif self.logger: + self.logger.debug( + f"[HomeLoad] No forecast provider for device '{device.name}' " + f"(forecast_provider_id={device.energy_load_forecast_provider_id})" + ) + + # --- Mix forecast with last real measurement (α/β blending) --- + if device_forecast.intervals and device_history.intervals: + last_real_power = 
device_history.intervals[-1].avg_power + device_forecast = LoadEnergyConsumption.mix( + device_forecast, + last_real_power, + alpha=self.forecast_mix_alpha, + beta=self.forecast_mix_beta, + ) + + per_device.append(self._make_device_consumption(device, device_history, device_forecast)) + + return HomeLoadsConsumption( + per_device=per_device, + total_history=self._sum_consumptions([d.history for d in per_device]), + total_forecast=self._sum_consumptions([d.forecast for d in per_device]), + ) + + @staticmethod + def _make_device_consumption( + device: LoadDevice, + history: LoadEnergyConsumption, + forecast: LoadEnergyConsumption, + ) -> LoadDeviceConsumption: + return LoadDeviceConsumption( + device_id=device.id, + device_name=device.name, + device_category=device.category, + history=history, + forecast=forecast, + ) + async def _build_mining_performance_snapshot( self, tracker: MiningPerformanceTrackerPort, @@ -107,8 +252,7 @@ async def _build_mining_performance_snapshot( except Exception as e: if self.logger: self.logger.warning( - f"Error getting mining performance tracker " - f"for optimization unit '{optimization_unit_name}': {e}" + f"Error getting mining performance tracker for optimization unit '{optimization_unit_name}': {e}" ) return None @@ -199,22 +343,55 @@ async def get_decisional_context(self, optimization_unit_id: EntityId) -> Option f"Skipping optimization unit." 
) - # --- Home Forecast Provider --- - home_forecast_provider: Optional[HomeForecastProviderPort] = None - if optimization_unit.home_forecast_provider_id: - home_forecast_provider = self.adapter_service.get_home_load_forecast_provider( - optimization_unit.home_forecast_provider_id - ) - # Home forecast provider is optional, so log a warning if it's missing but - # continue - if not home_forecast_provider: - if self.logger: - self.logger.warning( - f"Home forecast provider for " - f"optimization unit '{optimization_unit.name}' " - f"(Config ID: {optimization_unit.home_forecast_provider_id}) " - "not found. Skipping forecast provider." - ) + # --- Home Loads --- + home_loads_profile: Optional[HomeLoadsProfile] = None + if optimization_unit.home_loads_profile: + profile = self.home_loads_repo.get_by_id(optimization_unit.home_loads_profile) + if profile: + home_loads_profile = profile + + # --- Home Loads Forecast Provider --- + energy_load_forecast_providers: Dict[EntityId, EnergyLoadForecastProviderPort] = {} + if home_loads_profile and home_loads_profile.devices: + for load_device in home_loads_profile.devices: + if load_device.energy_load_forecast_provider_id: + energy_load_forecast_provider = self.adapter_service.get_home_load_forecast_provider( + load_device.energy_load_forecast_provider_id + ) + # Energy load forecast provider is optional, so log a warning if it's + # missing but continue + if not energy_load_forecast_provider: + if self.logger: + self.logger.warning( + f"Energy load forecast provider for " + f"load device '{load_device.name}' of " + f"optimization unit '{optimization_unit.name}' " + f"(Config ID: {load_device.energy_load_forecast_provider_id}) " + "not found. Skipping forecast provider." 
+ ) + + if energy_load_forecast_provider: + energy_load_forecast_providers[load_device.id] = energy_load_forecast_provider + + # --- Home Loads History Provider --- + energy_load_history_providers: Dict[EntityId, EnergyLoadHistoryProviderPort] = {} + if home_loads_profile and home_loads_profile.devices: + for load_device in home_loads_profile.devices: + if load_device.energy_load_history_provider_id: + energy_load_history_provider = await self.adapter_service.get_home_load_history_provider( + load_device.energy_load_history_provider_id, load_device.id + ) + if not energy_load_history_provider: + if self.logger: + self.logger.warning( + f"Energy load history provider for " + f"load device '{load_device.name}' of " + f"optimization unit '{optimization_unit.name}' " + f"(Config ID: {load_device.energy_load_history_provider_id}) " + "not found. Skipping history provider." + ) + else: + energy_load_history_providers[load_device.id] = energy_load_history_provider # --- Energy State --- if energy_source and energy_monitor: @@ -254,22 +431,13 @@ async def get_decisional_context(self, optimization_unit_id: EntityId) -> Option f"Error getting solar forecast for optimization unit '{optimization_unit.name}': {e}" ) - # --- Home Load Forecast --- - home_load_forecast: Optional[ConsumptionForecast] = None - if home_forecast_provider: - try: - # TODO: Provide parameters if needed - home_load_forecast = home_forecast_provider.get_home_consumption_forecast() - except Exception as e: - if self.logger: - self.logger.warning( - f"Error getting home load forecast for optimization unit '{optimization_unit.name}': {e}" - ) - else: - if self.logger: - self.logger.info( - f"No home load forecast provider configured for optimization unit '{optimization_unit.name}'." 
- ) + # --- Home Load Consumption (per-device history + forecast) --- + home_load = await self._build_home_loads_consumption( + home_loads_profile, + energy_load_forecast_providers, + energy_load_history_providers, + optimization_unit.name, + ) # --- Target Miners --- # Process only the first enabled miner in the optimization unit @@ -368,9 +536,16 @@ async def get_decisional_context(self, optimization_unit_id: EntityId) -> Option mining_performance: Optional[MiningPerformanceSnapshot] = None mining_performance_tracker: Optional[MiningPerformanceTrackerPort] = None if optimization_unit.performance_tracker_id: - mining_performance_tracker = await self.adapter_service.get_mining_performance_tracker( - optimization_unit.performance_tracker_id - ) + try: + mining_performance_tracker = await self.adapter_service.get_mining_performance_tracker( + optimization_unit.performance_tracker_id + ) + except Exception as e: + if self.logger: + self.logger.error( + f"Error getting mining performance tracker for optimization unit " + f"'{optimization_unit.name}': {e}" + ) # Mining performance tracker is optional, so log a warning if it's missing # but continue if not mining_performance_tracker: @@ -400,7 +575,7 @@ async def get_decisional_context(self, optimization_unit_id: EntityId) -> Option energy_source=energy_source, energy_state=energy_state, forecast=forecast_data, - home_load_forecast=home_load_forecast, + home_load=home_load, mining_performance=mining_performance, sun=sun, miner=miner, @@ -531,28 +706,62 @@ async def _process_unit(self, optimization_unit: EnergyOptimizationUnit): f"Skipping optimization unit." 
) - # --- Home Forecast Provider --- - home_forecast_provider: Optional[HomeForecastProviderPort] = None - if optimization_unit.home_forecast_provider_id: - try: - home_forecast_provider = self.adapter_service.get_home_load_forecast_provider( - optimization_unit.home_forecast_provider_id - ) - except Exception as e: - if self.logger: - self.logger.error( - f"Error getting home forecast provider for optimization unit '{optimization_unit.name}': {e}" + # --- Home Loads --- + home_loads_profile: Optional[HomeLoadsProfile] = None + if optimization_unit.home_loads_profile: + home_loads_profile = self.home_loads_repo.get_by_id(optimization_unit.home_loads_profile) + + # --- Energy Load Forecast Providers (per LoadDevice) --- + energy_load_forecast_providers: Dict[EntityId, EnergyLoadForecastProviderPort] = {} + if home_loads_profile and home_loads_profile.devices: + for load_device in home_loads_profile.devices: + if not load_device.energy_load_forecast_provider_id: + continue + try: + provider = self.adapter_service.get_home_load_forecast_provider( + load_device.energy_load_forecast_provider_id + ) + except Exception as e: + provider = None + if self.logger: + self.logger.error( + f"Error getting energy load forecast provider for load device " + f"'{load_device.name}' in optimization unit '{optimization_unit.name}': {e}" + ) + if provider: + energy_load_forecast_providers[load_device.id] = provider + elif self.logger: + self.logger.warning( + f"Energy load forecast provider for load device '{load_device.name}' " + f"(Config ID: {load_device.energy_load_forecast_provider_id}) not found. " + f"Skipping forecast provider for this device." 
+ ) + + # --- Energy Load History Providers (per LoadDevice) --- + energy_load_history_providers: Dict[EntityId, EnergyLoadHistoryProviderPort] = {} + if home_loads_profile and home_loads_profile.devices: + for load_device in home_loads_profile.devices: + if not load_device.energy_load_history_provider_id: + continue + try: + h_provider = self.adapter_service.get_home_load_history_provider( + load_device.energy_load_history_provider_id, load_device.id + ) + except Exception as e: + h_provider = None + if self.logger: + self.logger.error( + f"Error getting energy load history provider for load device " + f"'{load_device.name}' in optimization unit '{optimization_unit.name}': {e}" + ) + if h_provider: + energy_load_history_providers[load_device.id] = h_provider + elif self.logger: + self.logger.warning( + f"Energy load history provider for load device '{load_device.name}' " + f"(Config ID: {load_device.energy_load_history_provider_id}) not found. " + f"Skipping history provider for this device." ) - # Home forecast provider is optional, so log a warning if it's missing but - # continue - if not home_forecast_provider: - if self.logger: - self.logger.warning( - f"Home forecast provider for " - f"optimization unit '{optimization_unit.name}' " - f"(Config ID: {optimization_unit.home_forecast_provider_id}) " - "not found. Skipping forecast provider." - ) # --- Mining Performance Tracker --- mining_performance_tracker: Optional[MiningPerformanceTrackerPort] = None @@ -637,22 +846,13 @@ async def _process_unit(self, optimization_unit: EnergyOptimizationUnit): f"No solar forecast provider configured for optimization unit '{optimization_unit.name}'." 
) - # --- Home Load Forecast --- - home_load_forecast: Optional[ConsumptionForecast] = None - if home_forecast_provider: - try: - # TODO: Provide parameters if needed - home_load_forecast = home_forecast_provider.get_home_consumption_forecast() - except Exception as e: - if self.logger: - self.logger.warning( - f"Error getting home load forecast for optimization unit '{optimization_unit.name}': {e}" - ) - else: - if self.logger: - self.logger.debug( - f"No home load forecast provider configured for optimization unit '{optimization_unit.name}'." - ) + # --- Home Load Consumption (per-device history + forecast) --- + home_load = await self._build_home_loads_consumption( + home_loads_profile, + energy_load_forecast_providers, + energy_load_history_providers, + optimization_unit.name, + ) # --- Target Miners --- # Process each target miner in this optimization unit @@ -681,7 +881,7 @@ async def _process_unit(self, optimization_unit: EnergyOptimizationUnit): energy_source=energy_source, energy_state=energy_state, forecast=forecast_data, - home_load_forecast=home_load_forecast, + home_load=home_load, mining_performance=mining_performance, sun=sun, ) @@ -815,7 +1015,7 @@ async def _process_single_miner_in_unit( energy_source=context.energy_source, energy_state=context.energy_state, forecast=context.forecast, - home_load_forecast=context.home_load_forecast, + home_load=context.home_load, mining_performance=context.mining_performance, sun=context.sun, miner=miner, # Static config diff --git a/edge_mining/bootstrap.py b/edge_mining/bootstrap.py index 9d8532c..ea1db67 100644 --- a/edge_mining/bootstrap.py +++ b/edge_mining/bootstrap.py @@ -17,12 +17,21 @@ SqliteForecastProviderRepository, ) from edge_mining.adapters.domain.home_load.repositories import ( - InMemoryHomeForecastProviderRepository, + InMemoryEnergyLoadForecastProviderRepository, + InMemoryEnergyLoadHistoryProviderRepository, + InMemoryEnergyLoadHistoryRepository, InMemoryHomeLoadsProfileRepository, - 
SqlAlchemyHomeForecastProviderRepository, + InMemoryLoadConsumptionModelRepository, + SqlAlchemyEnergyLoadForecastProviderRepository, + SqlAlchemyEnergyLoadHistoryProviderRepository, + SqlAlchemyEnergyLoadHistoryRepository, SqlAlchemyHomeLoadsProfileRepository, - SqliteHomeForecastProviderRepository, + SqlAlchemyLoadConsumptionModelRepository, + SqliteEnergyLoadForecastProviderRepository, + SqliteEnergyLoadHistoryProviderRepository, + SqliteEnergyLoadHistoryRepository, SqliteHomeLoadsProfileRepository, + SqliteLoadConsumptionModelRepository, ) from edge_mining.adapters.domain.miner.repositories import ( InMemoryMinerControllerRepository, @@ -70,6 +79,8 @@ from edge_mining.application.interfaces import SunFactoryInterface from edge_mining.application.services.adapter_service import AdapterService from edge_mining.application.services.configuration_service import ConfigurationService +from edge_mining.application.services.home_load_history_service import HomeLoadHistoryService +from edge_mining.application.services.load_forecast_training_service import LoadForecastModelTrainingService from edge_mining.application.services.miner_action_service import MinerActionService from edge_mining.application.services.optimization_service import OptimizationService from edge_mining.domain.energy.ports import ( @@ -78,8 +89,11 @@ ) from edge_mining.domain.forecast.ports import ForecastProviderRepository from edge_mining.domain.home_load.ports import ( - HomeForecastProviderRepository, + EnergyLoadForecastProviderRepository, + EnergyLoadHistoryProviderRepository, + EnergyLoadHistoryRepository, HomeLoadsProfileRepository, + LoadConsumptionModelRepository, ) from edge_mining.domain.miner.ports import MinerControllerRepository, MinerRepository from edge_mining.domain.notification.ports import NotifierRepository @@ -152,7 +166,10 @@ def configure_persistence(logger: LoggerPort, settings: AppSettings) -> Persiste mining_performance_tracker_repo: MiningPerformanceTrackerRepository 
settings_repo: SettingsRepository home_profile_repo: HomeLoadsProfileRepository - home_forecast_provider_repo: HomeForecastProviderRepository + energy_load_forecast_provider_repo: EnergyLoadForecastProviderRepository + energy_load_history_provider_repo: EnergyLoadHistoryProviderRepository + home_load_history_repo: EnergyLoadHistoryRepository + load_consumption_model_repo: LoadConsumptionModelRepository optimization_unit_repo: EnergyOptimizationUnitRepository external_service_repo: ExternalServiceRepository @@ -168,7 +185,10 @@ def configure_persistence(logger: LoggerPort, settings: AppSettings) -> Persiste mining_performance_tracker_repo = InMemoryMiningPerformanceTrackerRepository() settings_repo = InMemorySettingsRepository() home_profile_repo = InMemoryHomeLoadsProfileRepository() - home_forecast_provider_repo = InMemoryHomeForecastProviderRepository() + energy_load_forecast_provider_repo = InMemoryEnergyLoadForecastProviderRepository() + energy_load_history_provider_repo = InMemoryEnergyLoadHistoryProviderRepository() + home_load_history_repo = InMemoryEnergyLoadHistoryRepository() + load_consumption_model_repo = InMemoryLoadConsumptionModelRepository() optimization_unit_repo = InMemoryOptimizationUnitRepository() external_service_repo = InMemoryExternalServiceRepository() @@ -190,7 +210,10 @@ def configure_persistence(logger: LoggerPort, settings: AppSettings) -> Persiste mining_performance_tracker_repo = SqliteMiningPerformanceTrackerRepository(db=sqlite_db) settings_repo = SqliteSettingsRepository(db=sqlite_db) home_profile_repo = SqliteHomeLoadsProfileRepository(db=sqlite_db) - home_forecast_provider_repo = SqliteHomeForecastProviderRepository(db=sqlite_db) + energy_load_forecast_provider_repo = SqliteEnergyLoadForecastProviderRepository(db=sqlite_db) + energy_load_history_provider_repo = SqliteEnergyLoadHistoryProviderRepository(db=sqlite_db) + home_load_history_repo = SqliteEnergyLoadHistoryRepository(db=sqlite_db) + load_consumption_model_repo = 
SqliteLoadConsumptionModelRepository(db=sqlite_db) optimization_unit_repo = SqliteOptimizationUnitRepository(db=sqlite_db) external_service_repo = SqliteExternalServiceRepository(db=sqlite_db) @@ -213,7 +236,10 @@ def configure_persistence(logger: LoggerPort, settings: AppSettings) -> Persiste mining_performance_tracker_repo = SqlAlchemyMiningPerformanceTrackerRepository(db=sqlalchemy_db) settings_repo = SqlAlchemySettingsRepository(db=sqlalchemy_db) home_profile_repo = SqlAlchemyHomeLoadsProfileRepository(db=sqlalchemy_db) - home_forecast_provider_repo = SqlAlchemyHomeForecastProviderRepository(db=sqlalchemy_db) + energy_load_forecast_provider_repo = SqlAlchemyEnergyLoadForecastProviderRepository(db=sqlalchemy_db) + energy_load_history_provider_repo = SqlAlchemyEnergyLoadHistoryProviderRepository(db=sqlalchemy_db) + home_load_history_repo = SqlAlchemyEnergyLoadHistoryRepository(db=sqlalchemy_db) + load_consumption_model_repo = SqlAlchemyLoadConsumptionModelRepository(db=sqlalchemy_db) optimization_unit_repo = SqlAlchemyOptimizationUnitRepository(db=sqlalchemy_db) external_service_repo = SqlAlchemyExternalServiceRepository(db=sqlalchemy_db) @@ -259,7 +285,10 @@ def configure_persistence(logger: LoggerPort, settings: AppSettings) -> Persiste miner_controller_repo=miner_controller_repo, forecast_provider_repo=forecast_provider_repo, home_profile_repo=home_profile_repo, - home_forecast_provider_repo=home_forecast_provider_repo, + energy_load_forecast_provider_repo=energy_load_forecast_provider_repo, + energy_load_history_provider_repo=energy_load_history_provider_repo, + home_load_history_repo=home_load_history_repo, + load_consumption_model_repo=load_consumption_model_repo, notifier_repo=notifier_repo, optimization_unit_repo=optimization_unit_repo, policy_repo=policy_repo, @@ -300,11 +329,14 @@ def configure_dependencies(logger: LoggerPort, settings: AppSettings) -> Service miner_repo=persistence_settings.miner_repo, notifier_repo=persistence_settings.notifier_repo, 
forecast_provider_repo=persistence_settings.forecast_provider_repo, - home_forecast_provider_repo=persistence_settings.home_forecast_provider_repo, + energy_load_forecast_provider_repo=persistence_settings.energy_load_forecast_provider_repo, + energy_load_history_provider_repo=persistence_settings.energy_load_history_provider_repo, + home_load_history_repo=persistence_settings.home_load_history_repo, mining_performance_tracker_repo=persistence_settings.mining_performance_tracker_repo, external_service_repo=persistence_settings.external_service_repo, event_bus=event_bus, logger=logger, + load_consumption_model_repo=persistence_settings.load_consumption_model_repo, ) optimization_service = OptimizationService( @@ -312,10 +344,13 @@ def configure_dependencies(logger: LoggerPort, settings: AppSettings) -> Service energy_source_repo=persistence_settings.energy_source_repo, policy_repo=persistence_settings.policy_repo, miner_repo=persistence_settings.miner_repo, + home_loads_repo=persistence_settings.home_profile_repo, adapter_service=adapter_service, sun_factory=sun_factory, event_bus=event_bus, logger=logger, + forecast_mix_alpha=settings.forecast_mix_alpha, + forecast_mix_beta=settings.forecast_mix_beta, ) miner_action_service = MinerActionService( @@ -332,11 +367,28 @@ def configure_dependencies(logger: LoggerPort, settings: AppSettings) -> Service adapter_service=adapter_service, ) + home_load_history_service = HomeLoadHistoryService( + home_loads_repo=persistence_settings.home_profile_repo, + home_load_history_repo=persistence_settings.home_load_history_repo, + adapter_service=adapter_service, + event_bus=event_bus, + logger=logger, + ) + + load_forecast_training_service = LoadForecastModelTrainingService( + home_loads_repo=persistence_settings.home_profile_repo, + history_repo=persistence_settings.home_load_history_repo, + model_repo=persistence_settings.load_consumption_model_repo, + logger=logger, + ) + services = Services( adapter_service=adapter_service, 
optimization_service=optimization_service, miner_action_service=miner_action_service, configuration_service=config_service, + home_load_history_service=home_load_history_service, + load_forecast_training_service=load_forecast_training_service, event_bus=event_bus, ) diff --git a/edge_mining/domain/home_load/aggregate_roots.py b/edge_mining/domain/home_load/aggregate_roots.py index aae76af..d54f0a1 100644 --- a/edge_mining/domain/home_load/aggregate_roots.py +++ b/edge_mining/domain/home_load/aggregate_roots.py @@ -4,10 +4,11 @@ """ from dataclasses import dataclass, field -from typing import Dict +from typing import List from edge_mining.domain.common import AggregateRoot, EntityId from edge_mining.domain.home_load.entities import LoadDevice +from edge_mining.domain.home_load.exceptions import HomeLoadsProfileAddDeviceError @dataclass @@ -15,6 +16,29 @@ class HomeLoadsProfile(AggregateRoot): """Aggregate Root for the Home Loads.""" name: str = "Default Home Profile" - devices: Dict[EntityId, LoadDevice] = field(default_factory=dict) - # We might store aggregated historical data or patterns here - # For simplicity now, the forecasting logic is external (in the adapter) + devices: List[LoadDevice] = field(default_factory=list) + + def __post_init__(self) -> None: + """Enforce the device-name uniqueness invariant on construction.""" + seen: set[str] = set() + for device in self.devices: + if device.name in seen: + raise HomeLoadsProfileAddDeviceError(f"Duplicate device name '{device.name}' in profile '{self.name}'.") + seen.add(device.name) + + def add_device(self, device: LoadDevice) -> None: + """Append a device enforcing name uniqueness within this profile.""" + if any(existing.name == device.name for existing in self.devices): + raise HomeLoadsProfileAddDeviceError( + f"A device named '{device.name}' already exists in profile '{self.name}'." 
+ ) + self.devices.append(device) + + def remove_device(self, device_id: EntityId) -> LoadDevice: + """Remove a device by id; raises if not found.""" + from edge_mining.domain.home_load.exceptions import HomeLoadsProfileDeviceNotFoundError + + for idx, existing in enumerate(self.devices): + if existing.id == device_id: + return self.devices.pop(idx) + raise HomeLoadsProfileDeviceNotFoundError(f"Device with id {device_id} not found in profile '{self.name}'.") diff --git a/edge_mining/domain/home_load/common.py b/edge_mining/domain/home_load/common.py index 7c79967..7fcb4c8 100644 --- a/edge_mining/domain/home_load/common.py +++ b/edge_mining/domain/home_load/common.py @@ -2,10 +2,50 @@ Common classes for the Home Load domain of the Edge Mining application. """ +from enum import Enum + from edge_mining.domain.common import AdapterType -class HomeForecastProviderAdapter(AdapterType): - """Types of home forecast provider adapter.""" +class LoadDeviceCategory(Enum): + """ + Categories for load devices based on consumption patterns. + + CONTROLLABLE: + Programmable loads like washing machines or dishwashers. + Consumption is concentrated in specific time windows, with predictable patterns + based on user-selected start times. + CONTINUOUS: + Always-on or semi-continuous loads like fridges or boilers. + Repetitive pattern on an hourly/daily basis; they operate almost constantly with activation/deactivation cycles. + SEASONAL: + Weather-dependent loads like heating or air conditioning. + Heavily dependent on season and external temperature. + OCCASIONAL: + Devices with infrequent or irregular usage (vacuum cleaner, power tools). 
+ """ + + CONTROLLABLE = "controllable" + CONTINUOUS = "continuous" + SEASONAL = "seasonal" + OCCASIONAL = "occasional" + + +class EnergyLoadForecastProviderAdapter(AdapterType): + """Types of energy load forecast provider adapter.""" + + DUMMY = "dummy" + NAIVE_LAST_HOUR = "naive_last_hour" + NAIVE_PERSISTENCE = "naive_persistence" + SEASONAL_BASELINE = "seasonal_baseline" + SKFORECAST = "skforecast" + STATSMODELS = "statsmodels" + TYPICAL_PROFILE = "typical_profile" + XGBOOST = "xgboost" + + +class EnergyLoadHistoryProviderAdapter(AdapterType): + """Types of energy load history provider adapter.""" DUMMY = "dummy" + HOME_ASSISTANT_API = "home_assistant_api" diff --git a/edge_mining/domain/home_load/entities.py b/edge_mining/domain/home_load/entities.py index cf22b70..6dfcbbb 100644 --- a/edge_mining/domain/home_load/entities.py +++ b/edge_mining/domain/home_load/entities.py @@ -1,11 +1,16 @@ """Collection of Entities for the Home Consumption Analytics domain of the Edge Mining application.""" -from dataclasses import dataclass +from dataclasses import dataclass, field +from datetime import datetime from typing import Optional from edge_mining.domain.common import Entity, EntityId -from edge_mining.domain.home_load.common import HomeForecastProviderAdapter -from edge_mining.shared.interfaces.config import HomeForecastProviderConfig +from edge_mining.domain.home_load.common import ( + EnergyLoadForecastProviderAdapter, + EnergyLoadHistoryProviderAdapter, + LoadDeviceCategory, +) +from edge_mining.shared.interfaces.config import EnergyLoadForecastProviderConfig, EnergyLoadHistoryProviderConfig @dataclass @@ -13,15 +18,51 @@ class LoadDevice(Entity): """Entity for a load device.""" name: str = "" # e.g., "Dishwasher", "EV Charger" - type: str = "" # e.g., "Appliance", "Heating" - # Could store typical consumption patterns here but I'll think about it later + category: LoadDeviceCategory = LoadDeviceCategory.OCCASIONAL + enabled: bool = True # Whether the device is 
active in the system + + energy_load_forecast_provider_id: Optional[EntityId] = None # Energy load forecast provider to be used + energy_load_history_provider_id: Optional[EntityId] = None # Energy load history provider to be used + + +@dataclass +class EnergyLoadForecastProvider(Entity): + """Entity for an energy load forecast provider.""" + + name: str = "" + adapter_type: EnergyLoadForecastProviderAdapter = EnergyLoadForecastProviderAdapter.DUMMY + config: Optional[EnergyLoadForecastProviderConfig] = None + external_service_id: Optional[EntityId] = None @dataclass -class HomeForecastProvider(Entity): - """Entity for a home forecast provider.""" +class EnergyLoadHistoryProvider(Entity): + """Entity for an energy load history provider.""" name: str = "" - adapter_type: HomeForecastProviderAdapter = HomeForecastProviderAdapter.DUMMY - config: Optional[HomeForecastProviderConfig] = None + adapter_type: EnergyLoadHistoryProviderAdapter = EnergyLoadHistoryProviderAdapter.DUMMY + config: Optional[EnergyLoadHistoryProviderConfig] = None external_service_id: Optional[EntityId] = None + + +@dataclass +class LoadConsumptionModel(Entity): + """Entity for a trained ML model used by ML-based forecast providers. + + Stores model metadata and serialized weights. The forecast provider + adapter loads the model from this entity instead of re-training on + every forecast call. 
+ """ + + device_id: Optional[EntityId] = None # None = aggregate model for all devices + adapter_type: EnergyLoadForecastProviderAdapter = EnergyLoadForecastProviderAdapter.STATSMODELS + trained_at: Optional[datetime] = None + mae: Optional[float] = None # mean absolute error on holdout + rmse: Optional[float] = None # root mean squared error on holdout + samples_used: int = 0 # number of training samples + is_active: bool = False # promoted to production + model_bytes: Optional[bytes] = field(default=None, repr=False) # serialized model (pickle/joblib) + tuning_params: Optional[dict] = field(default=None) # best hyperparameters from Optuna tuning + backtest_mae: Optional[float] = None # MAE from rolling-window backtesting + backtest_rmse: Optional[float] = None # RMSE from rolling-window backtesting + backtest_folds: int = 0 # number of folds used in backtesting diff --git a/edge_mining/domain/home_load/events.py b/edge_mining/domain/home_load/events.py new file mode 100644 index 0000000..5a34a89 --- /dev/null +++ b/edge_mining/domain/home_load/events.py @@ -0,0 +1,24 @@ +"""Home load domain events.""" + +from dataclasses import dataclass +from typing import Optional + +from edge_mining.domain.common import DomainEvent, EntityId + + +@dataclass +class LoadConsumptionHistoryCollectedEvent(DomainEvent): + """Event emitted after collecting power points for a device.""" + + device_id: Optional[EntityId] = None + device_name: str = "" + points_collected: int = 0 + + +@dataclass +class LoadConsumptionHistoryPurgedEvent(DomainEvent): + """Event emitted after purging old power points for a device.""" + + device_id: Optional[EntityId] = None + device_name: str = "" + points_purged: int = 0 diff --git a/edge_mining/domain/home_load/exceptions.py b/edge_mining/domain/home_load/exceptions.py index 095c8a4..4f08d24 100644 --- a/edge_mining/domain/home_load/exceptions.py +++ b/edge_mining/domain/home_load/exceptions.py @@ -9,31 +9,85 @@ class HomeLoadError(DomainError): pass 
-class HomeForecastError(HomeLoadError): - """Base class for home forecast-specific errors.""" +class HomeLoadsProfileAlreadyExistsError(HomeLoadError): + """Home Loads Profile already exists.""" pass -class HomeForecastProviderError(HomeForecastError): - """Errors related to home forecast provider.""" +class HomeLoadsProfileNotFoundError(HomeLoadError): + """Home Loads Profile not found.""" pass -class HomeForecastProviderNotFoundError(HomeForecastProviderError): - """Home Forecast Provider not found.""" +class HomeLoadsProfileAddDeviceError(HomeLoadError): + """Error adding device to Home Loads Profile.""" pass -class HomeForecastProviderAlreadyExistsError(HomeForecastProviderError): - """Home Forecast Provider already exists.""" +class HomeLoadsProfileDeviceNotFoundError(HomeLoadError): + """Load Device not found in Home Loads Profile.""" pass -class HomeForecastProviderConfigurationError(HomeForecastProviderError): +class HomeLoadsProfileRemoveDeviceError(HomeLoadError): + """Error removing device from Home Loads Profile.""" + + pass + + +class EnergyLoadForecastError(HomeLoadError): + """Base class for energy load forecast-specific errors.""" + + pass + + +class EnergyLoadForecastProviderError(EnergyLoadForecastError): + """Errors related to energy load forecast provider.""" + + pass + + +class EnergyLoadForecastProviderNotFoundError(EnergyLoadForecastProviderError): + """Energy Load Forecast Provider not found.""" + + pass + + +class EnergyLoadForecastProviderAlreadyExistsError(EnergyLoadForecastProviderError): + """Energy Load Forecast Provider already exists.""" + + pass + + +class EnergyLoadForecastProviderConfigurationError(EnergyLoadForecastProviderError): + """Error with the configuration.""" + + pass + + +class EnergyLoadHistoryProviderError(HomeLoadError): + """Errors related to energy load history provider.""" + + pass + + +class EnergyLoadHistoryProviderNotFoundError(EnergyLoadHistoryProviderError): + """Energy Load History Provider not found.""" + 
+ pass + + +class EnergyLoadHistoryProviderAlreadyExistsError(EnergyLoadHistoryProviderError): + """Energy Load History Provider already exists.""" + + pass + + +class EnergyLoadHistoryProviderConfigurationError(EnergyLoadHistoryProviderError): """Error with the configuration.""" pass diff --git a/edge_mining/domain/home_load/ports.py b/edge_mining/domain/home_load/ports.py index 96f4705..91ef4f2 100644 --- a/edge_mining/domain/home_load/ports.py +++ b/edge_mining/domain/home_load/ports.py @@ -3,28 +3,117 @@ from abc import ABC, abstractmethod from typing import List, Optional -from edge_mining.domain.common import EntityId -from edge_mining.domain.home_load.common import HomeForecastProviderAdapter +from edge_mining.domain.common import EntityId, Timestamp from edge_mining.domain.home_load.aggregate_roots import HomeLoadsProfile -from edge_mining.domain.home_load.entities import HomeForecastProvider -from edge_mining.domain.home_load.value_objects import ConsumptionForecast +from edge_mining.domain.home_load.common import EnergyLoadForecastProviderAdapter, EnergyLoadHistoryProviderAdapter +from edge_mining.domain.home_load.entities import ( + EnergyLoadForecastProvider, + EnergyLoadHistoryProvider, + LoadConsumptionModel, +) +from edge_mining.domain.home_load.value_objects import HomeLoadEnergyInterval, HomeLoadPowerPoint, LoadEnergyConsumption -class HomeForecastProviderPort(ABC): - """Port for the Home Forecast Provider.""" +class EnergyLoadHistoryRepository(ABC): + """Port for device-scoped persistence of HomeLoadPowerPoint time series. - def __init__(self, home_forecast_provider_type: HomeForecastProviderAdapter): - """Initialize the HomeForecast Provider.""" - self.home_forecast_provider_type = home_forecast_provider_type + Every operation is scoped to a single ``LoadDevice`` via its ``device_id``: + the repository supports multiple devices as independent, per-key streams. 
+ """ @abstractmethod - def get_home_consumption_forecast(self, hours_ahead: int = 3) -> Optional[ConsumptionForecast]: + def add_power_point(self, device_id: EntityId, power_point: HomeLoadPowerPoint) -> None: + """Append a single power point for the given device.""" + raise NotImplementedError + + @abstractmethod + def add_power_points(self, device_id: EntityId, power_points: List[HomeLoadPowerPoint]) -> None: + """Append multiple power points for the given device in one batch.""" + raise NotImplementedError + + @abstractmethod + def get_power_points(self, device_id: EntityId, start: Timestamp, end: Timestamp) -> List[HomeLoadPowerPoint]: + """Retrieve power points for ``device_id`` within the window [start, end).""" + raise NotImplementedError + + @abstractmethod + def get_latest_timestamp(self, device_id: EntityId) -> Optional[Timestamp]: + """Return the newest timestamp stored for ``device_id``, or None if empty. + + Used by ingestion pipelines to resume fetching from the last known point + and by the rule engine to evaluate staleness. + """ + raise NotImplementedError + + @abstractmethod + def purge_before(self, device_id: EntityId, timestamp: Timestamp) -> int: + """Delete all power points for ``device_id`` with timestamp < ``timestamp``. + + Returns the number of rows deleted (useful for retention metrics). + """ + raise NotImplementedError + + @abstractmethod + def remove_power_points_by_time_range(self, device_id: EntityId, start: Timestamp, end: Timestamp) -> None: + """Remove all power points for ``device_id`` within the window [start, end).""" + raise NotImplementedError + + @abstractmethod + def clear_device_history(self, device_id: EntityId) -> int: + """Delete all power points for ``device_id``. + + Returns the number of rows deleted. """ - Provides an aggregated forecast of home energy consumption - for the specified period. Returns average Watts or a profile? - For now, let's return an estimated *average* Watts needed soon. 
- Refine later based on how OptimizationPolicy uses it. + raise NotImplementedError + + +class EnergyLoadHistoryProviderPort(ABC): + """Port for retrieving historical energy load consumption data for a single device. + + The port is device-scoped: each provider instance is bound at construction + time to the ``LoadDevice`` it covers. The underlying persistence (cache or + local repo) is an infrastructure concern of the concrete adapter — it is + NOT exposed on the port contract, so domain code can rely on history + providers without knowing whether they cache, stream or query live. + """ + + def __init__(self, device_id: EntityId, provider_type: EnergyLoadHistoryProviderAdapter): + """Initialize the EnergyLoadHistory Provider bound to ``device_id``.""" + self.device_id = device_id + self.provider_type = provider_type + + @abstractmethod + async def get_power_points(self, start: Timestamp, end: Timestamp) -> List[HomeLoadPowerPoint]: + """Retrieve raw power points for this device in the window [start, end).""" + raise NotImplementedError + + @abstractmethod + async def get_history(self, start: Timestamp, end: Timestamp) -> List[HomeLoadEnergyInterval]: + """Retrieve consumption intervals (typically 1h buckets) for this device.""" + raise NotImplementedError + + +class EnergyLoadForecastProviderPort(ABC): + """Port for the Energy Load Forecast Provider.""" + + def __init__(self, forecast_provider_type: EnergyLoadForecastProviderAdapter): + """Initialize the EnergyLoadForecast Provider.""" + self.forecast_provider_type = forecast_provider_type + + @property + def min_required_history_hours(self) -> int: + """Minimum hours of historical data required for this provider to produce a forecast. + + Providers that need more history should override this property. + Returns 0 by default (no minimum requirement). 
""" + return 0 + + @abstractmethod + def get_consumption_forecast( + self, consumption_history: LoadEnergyConsumption, hours_ahead: int = 3 + ) -> Optional[LoadEnergyConsumption]: + """Provide an aggregated forecast of load energy consumption based on the given history.""" raise NotImplementedError @@ -32,49 +121,139 @@ class HomeLoadsProfileRepository(ABC): """Port for the Home Loads Profile Repository.""" @abstractmethod - def get_profile( - self, - ) -> Optional[HomeLoadsProfile]: # Assuming single profile for now - """Get the home loads profile.""" + def add(self, profile: HomeLoadsProfile) -> None: + """Adds a new home loads profile to the repository.""" raise NotImplementedError @abstractmethod - def save_profile(self, profile: HomeLoadsProfile) -> None: - """Save the home loads profile.""" + def get_by_id(self, profile_id: EntityId) -> Optional[HomeLoadsProfile]: + """Retrieves an home loads profile by its ID.""" raise NotImplementedError + @abstractmethod + def get_all(self) -> List[HomeLoadsProfile]: + """Retrieves all home loads profiles in the repository.""" + raise NotImplementedError -class HomeForecastProviderRepository(ABC): - """Port for the Home Forecast Provider Repository.""" + @abstractmethod + def update(self, profile: HomeLoadsProfile) -> None: + """Updates the state of an existing home loads profile in the repository.""" + raise NotImplementedError @abstractmethod - def add(self, home_forecast_provider: HomeForecastProvider) -> None: - """Adds a new home forecast provider to the repository.""" + def remove(self, profile_id: EntityId) -> None: + """Removes an home loads profile from the repository.""" raise NotImplementedError @abstractmethod - def get_by_id(self, home_forecast_provider_id: EntityId) -> Optional[HomeForecastProvider]: - """Retrieves a home forecast provider by its ID.""" + def get_by_energy_load_forecast_provider_id(self, provider_id: EntityId) -> List[HomeLoadsProfile]: + """Retrieves profiles whose LoadDevices reference the 
given energy load forecast provider.""" raise NotImplementedError + +class EnergyLoadForecastProviderRepository(ABC): + """Port for the Energy Load Forecast Provider Repository.""" + @abstractmethod - def get_all(self) -> List[HomeForecastProvider]: - """Retrieves all home forecast providers in the repository.""" + def add(self, energy_load_forecast_provider: EnergyLoadForecastProvider) -> None: + """Adds a new energy load forecast provider to the repository.""" raise NotImplementedError @abstractmethod - def update(self, home_forecast_provider: HomeForecastProvider) -> None: - """Updates the state of an existing home forecast provider in the repository.""" + def get_by_id(self, energy_load_forecast_provider_id: EntityId) -> Optional[EnergyLoadForecastProvider]: + """Retrieves an energy load forecast provider by its ID.""" raise NotImplementedError @abstractmethod - def remove(self, home_forecast_provider_id: EntityId) -> None: - """Removes a home forecast provider from the repository.""" + def get_all(self) -> List[EnergyLoadForecastProvider]: + """Retrieves all energy load forecast providers in the repository.""" raise NotImplementedError @abstractmethod - def get_by_external_service_id(self, external_service_id: EntityId) -> List[HomeForecastProvider]: + def update(self, energy_load_forecast_provider: EnergyLoadForecastProvider) -> None: + """Updates the state of an existing energy load forecast provider in the repository.""" + raise NotImplementedError + + @abstractmethod + def remove(self, energy_load_forecast_provider_id: EntityId) -> None: + """Removes an energy load forecast provider from the repository.""" + raise NotImplementedError + + @abstractmethod + def get_by_external_service_id(self, external_service_id: EntityId) -> List[EnergyLoadForecastProvider]: """ - Retrieves all home forecast providers associated with a specific external service ID. + Retrieves all energy load forecast providers associated with a specific external service ID. 
""" raise NotImplementedError + + +class EnergyLoadHistoryProviderRepository(ABC): + """Port for the Energy Load History Provider Repository.""" + + @abstractmethod + def add(self, energy_load_history_provider: EnergyLoadHistoryProvider) -> None: + """Adds a new energy load history provider to the repository.""" + raise NotImplementedError + + @abstractmethod + def get_by_id(self, energy_load_history_provider_id: EntityId) -> Optional[EnergyLoadHistoryProvider]: + """Retrieves an energy load history provider by its ID.""" + raise NotImplementedError + + @abstractmethod + def get_all(self) -> List[EnergyLoadHistoryProvider]: + """Retrieves all energy load history providers in the repository.""" + raise NotImplementedError + + @abstractmethod + def update(self, energy_load_history_provider: EnergyLoadHistoryProvider) -> None: + """Updates the state of an existing energy load history provider.""" + raise NotImplementedError + + @abstractmethod + def remove(self, energy_load_history_provider_id: EntityId) -> None: + """Removes an energy load history provider from the repository.""" + raise NotImplementedError + + @abstractmethod + def get_by_external_service_id(self, external_service_id: EntityId) -> List[EnergyLoadHistoryProvider]: + """Retrieves all energy load history providers linked to a specific external service.""" + raise NotImplementedError + + +class LoadConsumptionModelRepository(ABC): + """Port for persistence of trained LoadConsumptionModel instances.""" + + @abstractmethod + def add(self, model: LoadConsumptionModel) -> None: + """Persist a newly trained model.""" + raise NotImplementedError + + @abstractmethod + def get_by_id(self, model_id: EntityId) -> Optional[LoadConsumptionModel]: + """Retrieve a model by ID.""" + raise NotImplementedError + + @abstractmethod + def get_active_model( + self, + adapter_type: EnergyLoadForecastProviderAdapter, + device_id: Optional[EntityId] = None, + ) -> Optional[LoadConsumptionModel]: + """Retrieve the currently 
active (promoted) model for a given adapter type and device.""" + raise NotImplementedError + + @abstractmethod + def get_all(self, device_id: Optional[EntityId] = None) -> List[LoadConsumptionModel]: + """Retrieve all models, optionally filtered by device_id.""" + raise NotImplementedError + + @abstractmethod + def update(self, model: LoadConsumptionModel) -> None: + """Update an existing model (e.g. promote to active).""" + raise NotImplementedError + + @abstractmethod + def remove(self, model_id: EntityId) -> None: + """Remove a model.""" + raise NotImplementedError diff --git a/edge_mining/domain/home_load/value_objects.py b/edge_mining/domain/home_load/value_objects.py index 81ba8d1..93ddf1a 100644 --- a/edge_mining/domain/home_load/value_objects.py +++ b/edge_mining/domain/home_load/value_objects.py @@ -1,16 +1,283 @@ """Collection of Value Objects for the Home Consumption Analytics domain of the Edge Mining application.""" from dataclasses import dataclass, field -from datetime import datetime -from typing import Dict +from datetime import datetime, timedelta, timezone +from typing import List, Optional -from edge_mining.domain.common import Timestamp, ValueObject, Watts +from edge_mining.domain.common import EntityId, Timestamp, ValueObject, WattHours, Watts +from edge_mining.domain.home_load.common import LoadDeviceCategory @dataclass(frozen=True) -class ConsumptionForecast(ValueObject): - """Value Object for a consumption forecast.""" +class HomeLoadPowerPoint(ValueObject): + """Value Object for a single home loads power consumption point.""" - # Predicted consumption for a future period - predicted_watts: Dict[Timestamp, Watts] = field(default_factory=dict) - generated_at: Timestamp = field(default_factory=Timestamp(datetime.now())) + timestamp: Timestamp + power: Watts + + +@dataclass(frozen=True) +class HomeLoadEnergyInterval(ValueObject): + """ + Value Object for a home load energy consumption interval. 
+ In most cases this can be understood as a 1 hour time range + """ + + start: Timestamp + end: Timestamp + energy: Optional[WattHours] = None + power_points: List[HomeLoadPowerPoint] = field(default_factory=list) + + def __post_init__(self): + """Post-initialization validation.""" + if self.start >= self.end: + raise ValueError("Interval start time must be before end time.") + + for point in self.power_points: + if not (self.start <= point.timestamp <= self.end): + raise ValueError( + f"Power point timestamp {point.timestamp} is outside the interval [{self.start}, {self.end}]." + ) + + @classmethod + def create_from_power_points( + cls, + start: Timestamp, + end: Timestamp, + power_points: List[HomeLoadPowerPoint], + ) -> "HomeLoadEnergyInterval": + """Factory method to create an interval and calculate its energy from power points.""" + total_power = sum(point.power for point in power_points) + avg_power = Watts(total_power / len(power_points)) if power_points else Watts(0.0) + + duration_hours = (end - start).total_seconds() / 3600.0 + calculated_energy = WattHours(avg_power * duration_hours) + + return cls( + start=start, + end=end, + power_points=power_points, + energy=calculated_energy, + ) + + @property + def duration(self) -> timedelta: + """Calculate the duration of the interval""" + return self.end - self.start + + @property + def avg_power(self) -> Watts: + """Calculate the average power over the interval.""" + if not self.power_points: + return Watts(0.0) + + total_power = sum(point.power for point in self.power_points) + return Watts(total_power / len(self.power_points)) if total_power else Watts(0.0) + + +@dataclass(frozen=True) +class LoadEnergyConsumption(ValueObject): + """ + Value Object for a time series of load energy consumption. + Intended to be agnostic: can represent history, forecast, per-device or aggregate. + Intervals are typically 1 hour time ranges. 
+    """
+
+    timestamp: Timestamp = field(default_factory=lambda: Timestamp(datetime.now(timezone.utc)))
+    intervals: List[HomeLoadEnergyInterval] = field(default_factory=list)
+
+    @property
+    def total_energy(self) -> WattHours:
+        """Sum of energy across all intervals."""
+        return WattHours(sum(float(i.energy) for i in self.intervals if i.energy is not None))
+
+    @property
+    def avg_energy(self) -> WattHours:
+        """Average of per-interval energy."""
+        if not self.intervals:
+            return WattHours(0.0)
+
+        total_energy = sum(float(interval.energy) for interval in self.intervals if interval.energy is not None)
+        return WattHours(total_energy / len(self.intervals))
+
+    @property
+    def avg_power(self) -> Watts:
+        """Average of per-interval average power."""
+        if not self.intervals:
+            return Watts(0.0)
+
+        total_power = sum(interval.avg_power for interval in self.intervals)
+        return Watts(total_power / len(self.intervals))
+
+    @property
+    def peak_power(self) -> Watts:
+        """Maximum avg_power observed across intervals."""
+        if not self.intervals:
+            return Watts(0.0)
+        return Watts(max(float(i.avg_power) for i in self.intervals))
+
+    def in_window(self, start: Timestamp, end: Timestamp) -> "LoadEnergyConsumption":
+        """Return a subset whose intervals overlap the given window [start, end)."""
+        if start >= end:
+            return LoadEnergyConsumption(timestamp=self.timestamp, intervals=[])
+        filtered = [i for i in self.intervals if i.start < end and i.end > start]
+        return LoadEnergyConsumption(timestamp=self.timestamp, intervals=filtered)
+
+    def in_next_hours(self, hours: int, now: Optional[Timestamp] = None) -> "LoadEnergyConsumption":
+        """Return a subset covering the next `hours` starting from `now` (defaults to the current UTC time)."""
+        anchor = now if now is not None else Timestamp(datetime.now(timezone.utc))
+        return self.in_window(anchor, Timestamp(anchor + timedelta(hours=hours)))
+
+    def in_last_hours(self, hours: int, now:
Optional[Timestamp] = None) -> "LoadEnergyConsumption": + """Return a subset covering the last `hours` up to `now`.""" + anchor = now if now is not None else Timestamp(datetime.now(timezone.utc)) + return self.in_window(Timestamp(anchor - timedelta(hours=hours)), anchor) + + # Pre-computed window properties for rule engine paths + # e.g. home_load.total_forecast.next_1h.total_energy + + @property + def next_1h(self) -> "LoadEnergyConsumption": + """Subset covering the next 1 hour from now.""" + return self.in_next_hours(1) + + @property + def next_2h(self) -> "LoadEnergyConsumption": + """Subset covering the next 2 hours from now.""" + return self.in_next_hours(2) + + @property + def next_4h(self) -> "LoadEnergyConsumption": + """Subset covering the next 4 hours from now.""" + return self.in_next_hours(4) + + @property + def next_6h(self) -> "LoadEnergyConsumption": + """Subset covering the next 6 hours from now.""" + return self.in_next_hours(6) + + @property + def next_8h(self) -> "LoadEnergyConsumption": + """Subset covering the next 8 hours from now.""" + return self.in_next_hours(8) + + @property + def next_12h(self) -> "LoadEnergyConsumption": + """Subset covering the next 12 hours from now.""" + return self.in_next_hours(12) + + @property + def next_24h(self) -> "LoadEnergyConsumption": + """Subset covering the next 24 hours from now.""" + return self.in_next_hours(24) + + @property + def last_1h(self) -> "LoadEnergyConsumption": + """Subset covering the last 1 hour up to now.""" + return self.in_last_hours(1) + + @property + def last_4h(self) -> "LoadEnergyConsumption": + """Subset covering the last 4 hours up to now.""" + return self.in_last_hours(4) + + @property + def last_12h(self) -> "LoadEnergyConsumption": + """Subset covering the last 12 hours up to now.""" + return self.in_last_hours(12) + + @property + def last_24h(self) -> "LoadEnergyConsumption": + """Subset covering the last 24 hours up to now.""" + return self.in_last_hours(24) + + 
@staticmethod + def mix( + forecast: "LoadEnergyConsumption", + last_real_power: Watts, + alpha: float = 0.5, + beta: float = 0.5, + ) -> "LoadEnergyConsumption": + """Blend the first forecast interval with the last measured power. + + Implements the mix formula: + + P_mix(k) = α · P̂(k) + β · P_real(k-1) + + Only the **first** interval is blended; the remaining forecast is + returned unchanged. This improves short-term accuracy when the + optimisation loop runs frequently (e.g. every 5 s). + + :param forecast: The original forecast consumption. + :param last_real_power: The most recent measured power value (W). + :param alpha: Weight for the forecast side (default 0.5). + :param beta: Weight for the real-measurement side (default 0.5). + :returns: A new ``LoadEnergyConsumption`` with the blended first interval. + """ + if not forecast.intervals: + return forecast + + first = forecast.intervals[0] + blended_power = Watts(alpha * first.avg_power + beta * float(last_real_power)) + + duration_hours = first.duration.total_seconds() / 3600.0 + blended_energy = WattHours(blended_power * duration_hours) if duration_hours > 0 else first.energy + + blended_interval = HomeLoadEnergyInterval( + start=first.start, + end=first.end, + energy=blended_energy, + power_points=first.power_points, + ) + + new_intervals = [blended_interval] + list(forecast.intervals[1:]) + return LoadEnergyConsumption(timestamp=forecast.timestamp, intervals=new_intervals) + + +@dataclass(frozen=True) +class LoadDeviceConsumption(ValueObject): + """Consumption (history + forecast) for a single LoadDevice. + + Binds the generic ``LoadEnergyConsumption`` time series to the identity + of a LoadDevice so downstream consumers (policy engine, UI) can reason + per-device without losing track of "who is consuming what". 
+ """ + + device_id: EntityId + device_name: str + device_category: LoadDeviceCategory + history: LoadEnergyConsumption = field(default_factory=LoadEnergyConsumption) + forecast: LoadEnergyConsumption = field(default_factory=LoadEnergyConsumption) + + +@dataclass(frozen=True) +class HomeLoadsConsumption(ValueObject): + """Unified household consumption view for the DecisionalContext. + + Carries: + - ``per_device``: individual device history+forecast, keyed by unique name. + - ``total_history`` / ``total_forecast``: aggregated household time series. + + Exposes ``devices`` as a name-indexed mapping for readable rule paths + (e.g., ``home_load.devices.boiler.forecast.total_energy``). + """ + + per_device: List[LoadDeviceConsumption] = field(default_factory=list) + total_history: LoadEnergyConsumption = field(default_factory=LoadEnergyConsumption) + total_forecast: LoadEnergyConsumption = field(default_factory=LoadEnergyConsumption) + + @property + def devices(self) -> "dict[str, LoadDeviceConsumption]": + """Device-name-indexed map for rule engine path navigation. + + Relies on the uniqueness invariant enforced by ``HomeLoadsProfile``. 
+ """ + return {d.device_name: d for d in self.per_device} + + def device_by_name(self, name: str) -> Optional[LoadDeviceConsumption]: + """Lookup by (unique) device name.""" + return self.devices.get(name) + + def device_by_id(self, device_id: EntityId) -> Optional[LoadDeviceConsumption]: + """Lookup by device id.""" + return next((d for d in self.per_device if d.device_id == device_id), None) diff --git a/edge_mining/domain/optimization_unit/aggregate_roots.py b/edge_mining/domain/optimization_unit/aggregate_roots.py index bce4b53..0324aa3 100644 --- a/edge_mining/domain/optimization_unit/aggregate_roots.py +++ b/edge_mining/domain/optimization_unit/aggregate_roots.py @@ -23,9 +23,9 @@ class EnergyOptimizationUnit(AggregateRoot): policy_id: Optional[EntityId] = None # Policy to be used for the optimization target_miner_ids: List[EntityId] = field(default_factory=list) # Miners to be controlled energy_source_id: Optional[EntityId] = None # Energy source to be used + home_loads_profile: Optional[EntityId] = None # Home loads to manage # References to adapters - home_forecast_provider_id: Optional[EntityId] = None # Home load forecast provider to be used performance_tracker_id: Optional[EntityId] = None # Performance tracker to be used notifier_ids: List[EntityId] = field(default_factory=list) # Notifiers to be used @@ -52,9 +52,9 @@ def assign_energy_source(self, energy_source_id: EntityId): """Assign an energy source to the energy optimization unit.""" self.energy_source_id = energy_source_id - def assign_home_forecast_provider(self, home_forecast_provider_id: EntityId): - """Assign a home load forecast provider to the energy optimization unit.""" - self.home_forecast_provider_id = home_forecast_provider_id + def assign_home_loads_profile(self, profile_id: EntityId): + """Assign a home loads profile to the energy optimization unit.""" + self.home_loads_profile = profile_id def assign_performance_tracker(self, performance_tracker_id: EntityId): """Assign a 
performance tracker to the energy optimization unit.""" diff --git a/edge_mining/domain/policy/value_objects.py b/edge_mining/domain/policy/value_objects.py index 4c84387..4eee9e2 100644 --- a/edge_mining/domain/policy/value_objects.py +++ b/edge_mining/domain/policy/value_objects.py @@ -9,7 +9,7 @@ from edge_mining.domain.energy.value_objects import EnergyStateSnapshot from edge_mining.domain.forecast.aggregate_root import Forecast from edge_mining.domain.forecast.value_objects import Sun -from edge_mining.domain.home_load.value_objects import ConsumptionForecast +from edge_mining.domain.home_load.value_objects import HomeLoadsConsumption from edge_mining.domain.miner.aggregate_roots import Miner from edge_mining.domain.miner.value_objects import MinerStateSnapshot from edge_mining.domain.performance.value_objects import MiningPerformanceSnapshot @@ -23,9 +23,10 @@ class DecisionalContext(ValueObject): energy_state: Optional[EnergyStateSnapshot] forecast: Optional[Forecast] - home_load_forecast: Optional[ConsumptionForecast] - mining_performance: Optional[MiningPerformanceSnapshot] + home_load: Optional[HomeLoadsConsumption] = None + + mining_performance: Optional[MiningPerformanceSnapshot] = None sun: Optional[Sun] = field(default=None) diff --git a/edge_mining/shared/adapter_configs/home_load.py b/edge_mining/shared/adapter_configs/home_load.py index 66402fb..bb0f711 100644 --- a/edge_mining/shared/adapter_configs/home_load.py +++ b/edge_mining/shared/adapter_configs/home_load.py @@ -5,25 +5,194 @@ from dataclasses import asdict, dataclass, field -from edge_mining.domain.home_load.common import HomeForecastProviderAdapter -from edge_mining.shared.interfaces.config import HomeForecastProviderConfig +from edge_mining.domain.home_load.common import ( + EnergyLoadForecastProviderAdapter, + EnergyLoadHistoryProviderAdapter, +) +from edge_mining.shared.interfaces.config import ( + EnergyLoadForecastProviderConfig, + EnergyLoadHistoryProviderConfig, +) 
 @dataclass(frozen=True)
-class HomeForecastProviderDummyConfig(HomeForecastProviderConfig):
+class EnergyLoadForecastProviderDummyConfig(EnergyLoadForecastProviderConfig):
     """
-    Home Forecast provider configuration. It encapsulate the configuration parameters
+    Energy Load Forecast provider configuration. It encapsulates the configuration parameters
     to retrieve home forecast data from a dummy provider.
     """
 
     load_power_max: float = field(default=500.0)
 
-    def is_valid(self, adapter_type: HomeForecastProviderAdapter) -> bool:
+    def is_valid(self, adapter_type: EnergyLoadForecastProviderAdapter) -> bool:
         """
         Check if the configuration is valid for the given adapter type.
-        For Dummy Home Forecast, it is always valid.
+        For Dummy Energy Load Forecast, it is always valid.
         """
-        return adapter_type == HomeForecastProviderAdapter.DUMMY
+        return adapter_type == EnergyLoadForecastProviderAdapter.DUMMY
+
+    def to_dict(self) -> dict:
+        """Converts the configuration object into a serializable dictionary"""
+        return {**asdict(self)}
+
+    @classmethod
+    def from_dict(cls, data: dict):
+        """Create a configuration object from a dictionary"""
+        return cls(**data)
+
+
+@dataclass(frozen=True)
+class EnergyLoadForecastProviderNaiveLastHourConfig(EnergyLoadForecastProviderConfig):
+    """Configuration for NaiveLastHour forecast provider."""
+
+    hours_ahead: int = field(default=3)
+
+    def is_valid(self, adapter_type: EnergyLoadForecastProviderAdapter) -> bool:
+        return adapter_type == EnergyLoadForecastProviderAdapter.NAIVE_LAST_HOUR
+
+    def to_dict(self) -> dict:
+        return {**asdict(self)}
+
+    @classmethod
+    def from_dict(cls, data: dict):
+        return cls(**data)
+
+
+@dataclass(frozen=True)
+class EnergyLoadForecastProviderNaivePersistenceConfig(EnergyLoadForecastProviderConfig):
+    """Configuration for NaivePersistence forecast provider (repeat yesterday's profile)."""
+
+    hours_ahead: int = field(default=24)
+    delta_days: int = field(default=1)
+
+    def is_valid(self, adapter_type:
EnergyLoadForecastProviderAdapter) -> bool: + return adapter_type == EnergyLoadForecastProviderAdapter.NAIVE_PERSISTENCE + + def to_dict(self) -> dict: + return {**asdict(self)} + + @classmethod + def from_dict(cls, data: dict): + return cls(**data) + + +@dataclass(frozen=True) +class EnergyLoadForecastProviderSeasonalBaselineConfig(EnergyLoadForecastProviderConfig): + """Configuration for SeasonalBaseline forecast provider.""" + + hours_ahead: int = field(default=3) + weeks_lookback: int = field(default=4) + + def is_valid(self, adapter_type: EnergyLoadForecastProviderAdapter) -> bool: + return adapter_type == EnergyLoadForecastProviderAdapter.SEASONAL_BASELINE + + def to_dict(self) -> dict: + return {**asdict(self)} + + @classmethod + def from_dict(cls, data: dict): + return cls(**data) + + +@dataclass(frozen=True) +class EnergyLoadForecastProviderTypicalProfileConfig(EnergyLoadForecastProviderConfig): + """Configuration for TypicalProfile forecast provider (monthly + weekly + hourly avg).""" + + hours_ahead: int = field(default=24) + weeks_lookback: int = field(default=8) + + def is_valid(self, adapter_type: EnergyLoadForecastProviderAdapter) -> bool: + return adapter_type == EnergyLoadForecastProviderAdapter.TYPICAL_PROFILE + + def to_dict(self) -> dict: + return {**asdict(self)} + + @classmethod + def from_dict(cls, data: dict): + return cls(**data) + + +@dataclass(frozen=True) +class EnergyLoadForecastProviderSkforecastConfig(EnergyLoadForecastProviderConfig): + """Configuration for skforecast ForecasterRecursive provider. + + ``sklearn_model`` selects the sklearn regressor backend by name, e.g. + ``"RandomForestRegressor"``, ``"Ridge"``, ``"KNeighborsRegressor"`` etc. 
+ """ + + hours_ahead: int = field(default=24) + weeks_lookback: int = field(default=8) + sklearn_model: str = field(default="RandomForestRegressor") + num_lags: int = field(default=72) + + def is_valid(self, adapter_type: EnergyLoadForecastProviderAdapter) -> bool: + return adapter_type == EnergyLoadForecastProviderAdapter.SKFORECAST + + def to_dict(self) -> dict: + return {**asdict(self)} + + @classmethod + def from_dict(cls, data: dict): + return cls(**data) + + +@dataclass(frozen=True) +class EnergyLoadForecastProviderStatsmodelsConfig(EnergyLoadForecastProviderConfig): + """Configuration for Statsmodels (Holt-Winters / SARIMA) forecast provider.""" + + hours_ahead: int = field(default=3) + weeks_lookback: int = field(default=8) + method: str = field(default="hw") # "hw" (Holt-Winters) or "sarima" + seasonal_periods: int = field(default=24) # hours in a seasonal cycle + + def is_valid(self, adapter_type: EnergyLoadForecastProviderAdapter) -> bool: + return adapter_type == EnergyLoadForecastProviderAdapter.STATSMODELS + + def to_dict(self) -> dict: + return {**asdict(self)} + + @classmethod + def from_dict(cls, data: dict): + return cls(**data) + + +@dataclass(frozen=True) +class EnergyLoadForecastProviderXGBoostConfig(EnergyLoadForecastProviderConfig): + """Configuration for XGBoost forecast provider.""" + + hours_ahead: int = field(default=3) + weeks_lookback: int = field(default=8) + n_estimators: int = field(default=100) + max_depth: int = field(default=6) + learning_rate: float = field(default=0.1) + + def is_valid(self, adapter_type: EnergyLoadForecastProviderAdapter) -> bool: + return adapter_type == EnergyLoadForecastProviderAdapter.XGBOOST + + def to_dict(self) -> dict: + return {**asdict(self)} + + @classmethod + def from_dict(cls, data: dict): + return cls(**data) + + +@dataclass(frozen=True) +class EnergyLoadHistoryProviderHomeAssistantAPIConfig(EnergyLoadHistoryProviderConfig): + """ + Energy Load History provider configuration. 
It encapsulates the configuration parameters
+    to retrieve historical energy load data from the Home Assistant API.
+    """
+
+    entity_power: str = field(default="")
+    unit_power: str = field(default="W")
+
+    def is_valid(self, adapter_type: EnergyLoadHistoryProviderAdapter) -> bool:
+        """
+        Check if the configuration is valid for the given adapter type.
+        Valid only for the Home Assistant API adapter type.
+        """
+        return adapter_type == EnergyLoadHistoryProviderAdapter.HOME_ASSISTANT_API
 
     def to_dict(self) -> dict:
         """Converts the configuration object into a serializable dictionary"""
diff --git a/edge_mining/shared/adapter_maps/home_load.py b/edge_mining/shared/adapter_maps/home_load.py
index 3e13c42..af4994e 100644
--- a/edge_mining/shared/adapter_maps/home_load.py
+++ b/edge_mining/shared/adapter_maps/home_load.py
@@ -5,15 +5,57 @@
 
 from typing import Dict, Optional
 
-from edge_mining.domain.home_load.common import HomeForecastProviderAdapter
-from edge_mining.shared.adapter_configs.home_load import HomeForecastProviderDummyConfig
+from edge_mining.domain.home_load.common import EnergyLoadForecastProviderAdapter, EnergyLoadHistoryProviderAdapter
+from edge_mining.shared.adapter_configs.home_load import (
+    EnergyLoadForecastProviderDummyConfig,
+    EnergyLoadForecastProviderNaiveLastHourConfig,
+    EnergyLoadForecastProviderNaivePersistenceConfig,
+    EnergyLoadForecastProviderSeasonalBaselineConfig,
+    EnergyLoadForecastProviderSkforecastConfig,
+    EnergyLoadForecastProviderStatsmodelsConfig,
+    EnergyLoadForecastProviderTypicalProfileConfig,
+    EnergyLoadForecastProviderXGBoostConfig,
+    EnergyLoadHistoryProviderHomeAssistantAPIConfig,
+)
 from edge_mining.shared.external_services.common import ExternalServiceAdapter
-from edge_mining.shared.interfaces.config import HomeForecastProviderConfig
+from edge_mining.shared.interfaces.config import EnergyLoadForecastProviderConfig, EnergyLoadHistoryProviderConfig
 
-HOME_FORECAST_PROVIDER_CONFIG_TYPE_MAP: Dict[
-    HomeForecastProviderAdapter,
Optional[type[HomeForecastProviderConfig]] -] = {HomeForecastProviderAdapter.DUMMY: HomeForecastProviderDummyConfig} +ENERGY_LOAD_FORECAST_PROVIDER_CONFIG_TYPE_MAP: Dict[ + EnergyLoadForecastProviderAdapter, Optional[type[EnergyLoadForecastProviderConfig]] +] = { + EnergyLoadForecastProviderAdapter.DUMMY: EnergyLoadForecastProviderDummyConfig, + EnergyLoadForecastProviderAdapter.NAIVE_LAST_HOUR: EnergyLoadForecastProviderNaiveLastHourConfig, + EnergyLoadForecastProviderAdapter.NAIVE_PERSISTENCE: EnergyLoadForecastProviderNaivePersistenceConfig, + EnergyLoadForecastProviderAdapter.SEASONAL_BASELINE: EnergyLoadForecastProviderSeasonalBaselineConfig, + EnergyLoadForecastProviderAdapter.SKFORECAST: EnergyLoadForecastProviderSkforecastConfig, + EnergyLoadForecastProviderAdapter.STATSMODELS: EnergyLoadForecastProviderStatsmodelsConfig, + EnergyLoadForecastProviderAdapter.TYPICAL_PROFILE: EnergyLoadForecastProviderTypicalProfileConfig, + EnergyLoadForecastProviderAdapter.XGBOOST: EnergyLoadForecastProviderXGBoostConfig, +} + +ENERGY_LOAD_FORECAST_PROVIDER_EXTERNAL_SERVICE_MAP: Dict[ + EnergyLoadForecastProviderAdapter, Optional[ExternalServiceAdapter] +] = { + EnergyLoadForecastProviderAdapter.DUMMY: None, + EnergyLoadForecastProviderAdapter.NAIVE_LAST_HOUR: None, + EnergyLoadForecastProviderAdapter.NAIVE_PERSISTENCE: None, + EnergyLoadForecastProviderAdapter.SEASONAL_BASELINE: None, + EnergyLoadForecastProviderAdapter.SKFORECAST: None, + EnergyLoadForecastProviderAdapter.STATSMODELS: None, + EnergyLoadForecastProviderAdapter.TYPICAL_PROFILE: None, + EnergyLoadForecastProviderAdapter.XGBOOST: None, +} + +ENERGY_LOAD_HISTORY_PROVIDER_CONFIG_TYPE_MAP: Dict[ + EnergyLoadHistoryProviderAdapter, Optional[type[EnergyLoadHistoryProviderConfig]] +] = { + EnergyLoadHistoryProviderAdapter.DUMMY: None, + EnergyLoadHistoryProviderAdapter.HOME_ASSISTANT_API: EnergyLoadHistoryProviderHomeAssistantAPIConfig, +} -HOME_FORECAST_PROVIDER_EXTERNAL_SERVICE_MAP: 
Dict[HomeForecastProviderAdapter, Optional[ExternalServiceAdapter]] = { - HomeForecastProviderAdapter.DUMMY: None # Dummy does not use an external service +ENERGY_LOAD_HISTORY_PROVIDER_EXTERNAL_SERVICE_MAP: Dict[ + EnergyLoadHistoryProviderAdapter, Optional[ExternalServiceAdapter] +] = { + EnergyLoadHistoryProviderAdapter.DUMMY: None, + EnergyLoadHistoryProviderAdapter.HOME_ASSISTANT_API: ExternalServiceAdapter.HOME_ASSISTANT_API, } diff --git a/edge_mining/shared/external_services/value_objects.py b/edge_mining/shared/external_services/value_objects.py index db601dc..b535005 100644 --- a/edge_mining/shared/external_services/value_objects.py +++ b/edge_mining/shared/external_services/value_objects.py @@ -6,7 +6,7 @@ from edge_mining.domain.common import ValueObject from edge_mining.domain.energy.entities import EnergyMonitor from edge_mining.domain.forecast.entities import ForecastProvider -from edge_mining.domain.home_load.entities import HomeForecastProvider +from edge_mining.domain.home_load.entities import EnergyLoadForecastProvider, EnergyLoadHistoryProvider from edge_mining.domain.miner.entities import MinerController from edge_mining.domain.notification.entities import Notifier @@ -18,5 +18,6 @@ class ExternalServiceLinkedEntities(ValueObject): miner_controllers: List[MinerController] energy_monitors: List[EnergyMonitor] forecast_providers: List[ForecastProvider] - home_forecast_providers: List[HomeForecastProvider] + energy_load_forecast_providers: List[EnergyLoadForecastProvider] + energy_load_history_providers: List[EnergyLoadHistoryProvider] notifiers: List[Notifier] diff --git a/edge_mining/shared/infrastructure.py b/edge_mining/shared/infrastructure.py index cd461e6..5ab2a81 100644 --- a/edge_mining/shared/infrastructure.py +++ b/edge_mining/shared/infrastructure.py @@ -2,12 +2,15 @@ from dataclasses import dataclass from enum import Enum +from typing import Optional from edge_mining.application.interfaces import ( MinerActionServiceInterface, 
AdapterServiceInterface, ConfigurationServiceInterface, EventBusInterface, + HomeLoadHistoryServiceInterface, + LoadForecastTrainingServiceInterface, OptimizationServiceInterface, ) from edge_mining.domain.energy.ports import ( @@ -16,8 +19,11 @@ ) from edge_mining.domain.forecast.ports import ForecastProviderRepository from edge_mining.domain.home_load.ports import ( - HomeForecastProviderRepository, + EnergyLoadForecastProviderRepository, + EnergyLoadHistoryProviderRepository, + EnergyLoadHistoryRepository, HomeLoadsProfileRepository, + LoadConsumptionModelRepository, ) from edge_mining.domain.miner.ports import MinerControllerRepository, MinerRepository from edge_mining.domain.notification.ports import NotifierRepository @@ -45,7 +51,10 @@ class PersistenceSettings: miner_controller_repo: MinerControllerRepository forecast_provider_repo: ForecastProviderRepository home_profile_repo: HomeLoadsProfileRepository - home_forecast_provider_repo: HomeForecastProviderRepository + energy_load_forecast_provider_repo: EnergyLoadForecastProviderRepository + energy_load_history_provider_repo: EnergyLoadHistoryProviderRepository + home_load_history_repo: EnergyLoadHistoryRepository + load_consumption_model_repo: LoadConsumptionModelRepository policy_repo: OptimizationPolicyRepository mining_performance_tracker_repo: MiningPerformanceTrackerRepository optimization_unit_repo: EnergyOptimizationUnitRepository @@ -62,4 +71,6 @@ class Services: optimization_service: OptimizationServiceInterface miner_action_service: MinerActionServiceInterface configuration_service: ConfigurationServiceInterface + home_load_history_service: HomeLoadHistoryServiceInterface + load_forecast_training_service: Optional[LoadForecastTrainingServiceInterface] event_bus: EventBusInterface diff --git a/edge_mining/shared/interfaces/config.py b/edge_mining/shared/interfaces/config.py index 3ed452a..758790f 100644 --- a/edge_mining/shared/interfaces/config.py +++ b/edge_mining/shared/interfaces/config.py @@ 
-4,7 +4,7 @@ from edge_mining.domain.energy.common import EnergyMonitorAdapter from edge_mining.domain.forecast.common import ForecastProviderAdapter -from edge_mining.domain.home_load.common import HomeForecastProviderAdapter +from edge_mining.domain.home_load.common import EnergyLoadForecastProviderAdapter, EnergyLoadHistoryProviderAdapter from edge_mining.domain.miner.common import MinerControllerAdapter from edge_mining.domain.notification.common import NotificationAdapter from edge_mining.domain.performance.common import MiningPerformanceTrackerAdapter @@ -51,11 +51,20 @@ def is_valid(self, adapter_type: ForecastProviderAdapter) -> bool: pass -class HomeForecastProviderConfig(Configuration): - """Base interface for Home Loads Forecast Provider configurations.""" +class EnergyLoadForecastProviderConfig(Configuration): + """Base interface for Energy Load Forecast Provider configurations.""" @abstractmethod - def is_valid(self, adapter_type: HomeForecastProviderAdapter) -> bool: + def is_valid(self, adapter_type: EnergyLoadForecastProviderAdapter) -> bool: + """Check if the configuration is valid for the given adapter type.""" + pass + + +class EnergyLoadHistoryProviderConfig(Configuration): + """Base interface for Energy Load History Provider configurations.""" + + @abstractmethod + def is_valid(self, adapter_type: EnergyLoadHistoryProviderAdapter) -> bool: """Check if the configuration is valid for the given adapter type.""" pass diff --git a/edge_mining/shared/interfaces/factories.py b/edge_mining/shared/interfaces/factories.py index 8196fc8..fc4a216 100644 --- a/edge_mining/shared/interfaces/factories.py +++ b/edge_mining/shared/interfaces/factories.py @@ -4,6 +4,7 @@ from typing import Any, Optional from edge_mining.domain.energy.entities import EnergySource +from edge_mining.domain.home_load.entities import LoadDevice from edge_mining.domain.miner.aggregate_roots import Miner from edge_mining.shared.external_services.ports import ExternalServicePort from 
edge_mining.shared.interfaces.config import Configuration, ExternalServiceConfig @@ -64,5 +65,22 @@ def from_energy_source(self, energy_source: EnergySource) -> None: pass +class EnergyLoadForecastAdapterFactory(AdapterFactory): + """Abstract factory for energy load forecast adapters.""" + + +class EnergyLoadHistoryAdapterFactory(AdapterFactory): + """Abstract factory for energy load history adapters (device-scoped).""" + + @abstractmethod + def from_load_device(self, load_device: LoadDevice) -> None: + """Bind the factory to the LoadDevice this adapter will serve. + + Must be called before ``create`` so the resulting adapter knows its + ``device_id`` scope. + """ + pass + + class MiningPerformanceTrackerAdapterFactory(AdapterFactory): """Abstract factory for mining performance tracker adapters""" diff --git a/edge_mining/shared/settings/settings.py b/edge_mining/shared/settings/settings.py index bcc9047..c44fc81 100644 --- a/edge_mining/shared/settings/settings.py +++ b/edge_mining/shared/settings/settings.py @@ -31,6 +31,12 @@ class AppSettings(BaseSettings): # Scheduler settings scheduler_interval_seconds: int = 5 # Evaluate every 5 seconds + history_ingestion_interval_seconds: int = 120 # Collect power points every 2 minutes + history_retention_days: int = 90 # Purge power points older than 90 days + + # Forecast mix settings (α/β blending of forecast with last real measurement) + forecast_mix_alpha: float = 0.5 # Weight for the forecasted value + forecast_mix_beta: float = 0.5 # Weight for the last real measured value model_config = SettingsConfigDict( env_file=".env", # Load .env file if exists diff --git a/pyproject.toml b/pyproject.toml index 7f8480a..02d8db0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -70,6 +70,13 @@ solar = [ pyasic = [ "pyasic==0.78.10" ] +ml = [ + "scikit-learn>=1.5.0", + "statsmodels>=0.14.0", + "xgboost>=2.0.0", + "skforecast>=0.14", + "optuna>=3.0", +] all = [ "edge-mining[api,homeassistant,mqtt,telegram,solar,pyasic]", ] 
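The `forecast_mix_alpha` / `forecast_mix_beta` settings added above weight the α/β blending that `LoadEnergyConsumption.mix` applies to the first forecast interval (P_mix = α·P̂ + β·P_real). A minimal sketch of that arithmetic on plain floats — the helper name is hypothetical and the domain types (`Watts`, `WattHours`, intervals) are deliberately omitted:

```python
def mix_first_interval(forecast_powers, last_real_power, alpha=0.5, beta=0.5):
    """Blend only the first forecast value with the last measured power:
    P_mix(0) = alpha * P_hat(0) + beta * P_real; later intervals pass through."""
    if not forecast_powers:
        return []
    blended = alpha * forecast_powers[0] + beta * last_real_power
    return [blended] + list(forecast_powers[1:])


# Forecast says 400 W for the coming hour, but the meter just read 600 W:
# with alpha = beta = 0.5 the blended first interval is 500 W.
print(mix_first_interval([400.0, 450.0, 500.0], 600.0))  # [500.0, 450.0, 500.0]
```

Blending only the first interval keeps short-horizon decisions anchored to the live meter reading while leaving the rest of the forecast untouched, which matches the frequent re-evaluation loop described in the `mix` docstring.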
diff --git a/requirements.txt b/requirements.txt index 83c87cc..71fe8ab 100644 --- a/requirements.txt +++ b/requirements.txt @@ -20,3 +20,10 @@ homeassistant_api==4.2.2.post1 python-telegram-bot>=20.0 astral==3.2 pyasic==0.78.10 + +# Optional - For ML Forecast Adapters +scikit-learn>=1.5.0 +statsmodels>=0.14.0 +xgboost>=2.0.0 +skforecast>=0.14 +optuna>=3.0 diff --git a/tests/integration/adapters/persistence/test_sqlalchemy_home_load_repositories.py b/tests/integration/adapters/persistence/test_sqlalchemy_home_load_repositories.py index e8461c5..24e0c73 100644 --- a/tests/integration/adapters/persistence/test_sqlalchemy_home_load_repositories.py +++ b/tests/integration/adapters/persistence/test_sqlalchemy_home_load_repositories.py @@ -2,64 +2,64 @@ import pytest -from edge_mining.adapters.domain.home_load.repositories import SqlAlchemyHomeForecastProviderRepository +from edge_mining.adapters.domain.home_load.repositories import SqlAlchemyEnergyLoadForecastProviderRepository from edge_mining.adapters.infrastructure.persistence.sqlalchemy.base import BaseSQLAlchemyRepository -from edge_mining.domain.home_load.common import HomeForecastProviderAdapter -from edge_mining.domain.home_load.entities import HomeForecastProvider -from edge_mining.shared.adapter_configs.home_load import HomeForecastProviderDummyConfig +from edge_mining.domain.home_load.common import EnergyLoadForecastProviderAdapter +from edge_mining.domain.home_load.entities import EnergyLoadForecastProvider +from edge_mining.shared.adapter_configs.home_load import EnergyLoadForecastProviderDummyConfig -class TestSqlAlchemyHomeForecastProviderRepository: - """Integration tests for SqlAlchemyHomeForecastProviderRepository.""" +class TestSqlAlchemyEnergyLoadForecastProviderRepository: + """Integration tests for SqlAlchemyEnergyLoadForecastProviderRepository.""" @pytest.fixture - def repository(self, sqlalchemy_repo: BaseSQLAlchemyRepository) -> SqlAlchemyHomeForecastProviderRepository: - """Create a 
HomeForecastProvider repository instance."""
-        return SqlAlchemyHomeForecastProviderRepository(db=sqlalchemy_repo)
+    def repository(self, sqlalchemy_repo: BaseSQLAlchemyRepository) -> SqlAlchemyEnergyLoadForecastProviderRepository:
+        """Create an EnergyLoadForecastProvider repository instance."""
+        return SqlAlchemyEnergyLoadForecastProviderRepository(db=sqlalchemy_repo)
 
-    def test_add_and_get_home_forecast_provider_with_enum_adapter(
-        self, repository: SqlAlchemyHomeForecastProviderRepository
+    def test_add_and_get_energy_load_forecast_provider_with_enum_adapter(
+        self, repository: SqlAlchemyEnergyLoadForecastProviderRepository
     ):
         """Regression test: enum adapter_type must persist without sqlite binding errors."""
-        provider = HomeForecastProvider(
+        provider = EnergyLoadForecastProvider(
             name="Home Forecast Test",
-            adapter_type=HomeForecastProviderAdapter.DUMMY,
-            config=HomeForecastProviderDummyConfig(load_power_max=650.0),
+            adapter_type=EnergyLoadForecastProviderAdapter.DUMMY,
+            config=EnergyLoadForecastProviderDummyConfig(load_power_max=650.0),
         )
         original_id = provider.id
 
         repository.add(provider)
-        assert provider.adapter_type == HomeForecastProviderAdapter.DUMMY
+        assert provider.adapter_type == EnergyLoadForecastProviderAdapter.DUMMY
 
         retrieved = repository.get_by_id(original_id)
         assert retrieved is not None
         assert retrieved.id == original_id
         assert retrieved.name == "Home Forecast Test"
-        assert retrieved.adapter_type == HomeForecastProviderAdapter.DUMMY
-        assert isinstance(retrieved.config, HomeForecastProviderDummyConfig)
+        assert retrieved.adapter_type == EnergyLoadForecastProviderAdapter.DUMMY
+        assert isinstance(retrieved.config, EnergyLoadForecastProviderDummyConfig)
 
-    def test_update_home_forecast_provider_with_enum_adapter(
-        self, repository: SqlAlchemyHomeForecastProviderRepository
+    def test_update_energy_load_forecast_provider_with_enum_adapter(
+        self, repository: SqlAlchemyEnergyLoadForecastProviderRepository
     ):
         """Regression test: enum
adapter_type must remain valid through update commit.""" - provider = HomeForecastProvider( + provider = EnergyLoadForecastProvider( name="Original Home Forecast", - adapter_type=HomeForecastProviderAdapter.DUMMY, - config=HomeForecastProviderDummyConfig(load_power_max=400.0), + adapter_type=EnergyLoadForecastProviderAdapter.DUMMY, + config=EnergyLoadForecastProviderDummyConfig(load_power_max=400.0), ) repository.add(provider) provider.name = "Updated Home Forecast" - provider.adapter_type = HomeForecastProviderAdapter.DUMMY + provider.adapter_type = EnergyLoadForecastProviderAdapter.DUMMY repository.update(provider) - assert provider.adapter_type == HomeForecastProviderAdapter.DUMMY + assert provider.adapter_type == EnergyLoadForecastProviderAdapter.DUMMY retrieved = repository.get_by_id(provider.id) assert retrieved is not None assert retrieved.name == "Updated Home Forecast" - assert retrieved.adapter_type == HomeForecastProviderAdapter.DUMMY + assert retrieved.adapter_type == EnergyLoadForecastProviderAdapter.DUMMY if __name__ == "__main__": diff --git a/tests/unit/adapters/domain/home_load/__init__.py b/tests/unit/adapters/domain/home_load/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/unit/adapters/domain/home_load/test_home_load_api_endpoints.py b/tests/unit/adapters/domain/home_load/test_home_load_api_endpoints.py new file mode 100644 index 0000000..bd9a181 --- /dev/null +++ b/tests/unit/adapters/domain/home_load/test_home_load_api_endpoints.py @@ -0,0 +1,235 @@ +"""Unit tests for home load API endpoints: device history, training trigger, models list.""" + +import uuid +from datetime import datetime, timedelta +from unittest.mock import AsyncMock, MagicMock + +import pytest +from fastapi import FastAPI +from fastapi.testclient import TestClient + +from edge_mining.adapters.domain.home_load.fast_api.router import router +from edge_mining.adapters.infrastructure.api.setup import ( + get_config_service, + 
get_home_load_history_service, + get_load_forecast_training_service, +) +from edge_mining.domain.common import EntityId, Timestamp, Watts +from edge_mining.domain.home_load.aggregate_roots import HomeLoadsProfile +from edge_mining.domain.home_load.common import EnergyLoadForecastProviderAdapter +from edge_mining.domain.home_load.entities import LoadConsumptionModel, LoadDevice +from edge_mining.domain.home_load.value_objects import HomeLoadPowerPoint + + +# --- Fixtures --- + + +@pytest.fixture +def device_id() -> EntityId: + return EntityId(uuid.uuid4()) + + +@pytest.fixture +def profile_id() -> EntityId: + return EntityId(uuid.uuid4()) + + +@pytest.fixture +def profile_with_device(profile_id, device_id) -> HomeLoadsProfile: + device = LoadDevice(id=device_id, name="Dishwasher", enabled=True) + return HomeLoadsProfile(id=profile_id, name="Test Home", devices=[device]) + + +@pytest.fixture +def mock_config_service(profile_with_device): + svc = MagicMock() + svc.get_home_loads_profile.return_value = profile_with_device + return svc + + +@pytest.fixture +def mock_history_service(): + return MagicMock() + + +@pytest.fixture +def mock_training_service(): + svc = AsyncMock() + svc.get_models = MagicMock(return_value=[]) + return svc + + +@pytest.fixture +def client(mock_config_service, mock_history_service, mock_training_service): + app = FastAPI() + app.include_router(router, prefix="/api/v1") + + app.dependency_overrides[get_config_service] = lambda: mock_config_service + app.dependency_overrides[get_home_load_history_service] = lambda: mock_history_service + app.dependency_overrides[get_load_forecast_training_service] = lambda: mock_training_service + + return TestClient(app) + + +# --- Device History Endpoint Tests --- + + +class TestGetDeviceHistory: + def test_returns_power_points(self, client, mock_history_service, profile_id, device_id): + now = datetime.now() + points = [ + HomeLoadPowerPoint(timestamp=Timestamp(now - timedelta(hours=1)), power=Watts(100.0)), + 
HomeLoadPowerPoint(timestamp=Timestamp(now), power=Watts(200.0)), + ] + mock_history_service.get_device_history.return_value = points + + start = (now - timedelta(hours=2)).isoformat() + end = now.isoformat() + response = client.get( + f"/api/v1/home-loads-profiles/{profile_id}/devices/{device_id}/history", + params={"start": start, "end": end}, + ) + + assert response.status_code == 200 + data = response.json() + assert len(data) == 2 + assert data[0]["power"] == 100.0 + assert data[1]["power"] == 200.0 + + def test_returns_empty_list(self, client, mock_history_service, profile_id, device_id): + mock_history_service.get_device_history.return_value = [] + now = datetime.now() + + response = client.get( + f"/api/v1/home-loads-profiles/{profile_id}/devices/{device_id}/history", + params={"start": (now - timedelta(hours=1)).isoformat(), "end": now.isoformat()}, + ) + + assert response.status_code == 200 + assert response.json() == [] + + def test_profile_not_found(self, client, mock_config_service): + mock_config_service.get_home_loads_profile.return_value = None + unknown_profile = uuid.uuid4() + device = uuid.uuid4() + now = datetime.now() + + response = client.get( + f"/api/v1/home-loads-profiles/{unknown_profile}/devices/{device}/history", + params={"start": (now - timedelta(hours=1)).isoformat(), "end": now.isoformat()}, + ) + + assert response.status_code == 404 + + def test_device_not_found(self, client, profile_id): + unknown_device = uuid.uuid4() + now = datetime.now() + + response = client.get( + f"/api/v1/home-loads-profiles/{profile_id}/devices/{unknown_device}/history", + params={"start": (now - timedelta(hours=1)).isoformat(), "end": now.isoformat()}, + ) + + assert response.status_code == 404 + + +# --- Training Trigger Endpoint Tests --- + + +class TestTriggerTrainingAll: + def test_trigger_training_all_success(self, client, mock_training_service): + mock_training_service.train_all = AsyncMock() + + response = client.post("/api/v1/training/trigger") + 
+ assert response.status_code == 200 + data = response.json() + assert data["status"] == "completed" + + def test_trigger_training_all_with_weeks_lookback(self, client, mock_training_service): + mock_training_service.train_all = AsyncMock() + + response = client.post("/api/v1/training/trigger", params={"weeks_lookback": 4}) + + assert response.status_code == 200 + + +class TestTriggerTrainingDevice: + def test_trigger_device_training_success(self, client, mock_training_service, profile_id, device_id): + mock_training_service.train_device = AsyncMock() + + response = client.post( + f"/api/v1/home-loads-profiles/{profile_id}/devices/{device_id}/training/trigger", + ) + + assert response.status_code == 200 + data = response.json() + assert data["status"] == "completed" + assert "Dishwasher" in data["detail"] + + def test_trigger_device_training_profile_not_found(self, client, mock_config_service): + mock_config_service.get_home_loads_profile.return_value = None + unknown = uuid.uuid4() + device = uuid.uuid4() + + response = client.post( + f"/api/v1/home-loads-profiles/{unknown}/devices/{device}/training/trigger", + ) + + assert response.status_code == 404 + + def test_trigger_device_training_device_not_found(self, client, profile_id): + unknown_device = uuid.uuid4() + + response = client.post( + f"/api/v1/home-loads-profiles/{profile_id}/devices/{unknown_device}/training/trigger", + ) + + assert response.status_code == 404 + + +# --- Training Models List Endpoint Tests --- + + +class TestGetTrainingModels: + def test_list_models_empty(self, client, mock_training_service): + mock_training_service.get_models.return_value = [] + + response = client.get("/api/v1/training/models") + + assert response.status_code == 200 + assert response.json() == [] + + def test_list_models_returns_data(self, client, mock_training_service, device_id): + model = LoadConsumptionModel( + device_id=device_id, + adapter_type=EnergyLoadForecastProviderAdapter.STATSMODELS, + 
trained_at=datetime.now(), + mae=1.5, + rmse=2.0, + samples_used=100, + is_active=True, + ) + mock_training_service.get_models.return_value = [model] + + response = client.get("/api/v1/training/models") + + assert response.status_code == 200 + data = response.json() + assert len(data) == 1 + assert data[0]["mae"] == 1.5 + assert data[0]["is_active"] is True + assert data[0]["device_id"] == str(device_id) + + def test_list_models_filtered_by_device(self, client, mock_training_service, device_id): + mock_training_service.get_models.return_value = [] + + response = client.get("/api/v1/training/models", params={"device_id": str(device_id)}) + + assert response.status_code == 200 + mock_training_service.get_models.assert_called_once() + + def test_list_models_invalid_device_id(self, client): + response = client.get("/api/v1/training/models", params={"device_id": "not-a-uuid"}) + + assert response.status_code == 400 diff --git a/tests/unit/adapters/home_load/test_backtesting.py b/tests/unit/adapters/home_load/test_backtesting.py new file mode 100644 index 0000000..930675e --- /dev/null +++ b/tests/unit/adapters/home_load/test_backtesting.py @@ -0,0 +1,171 @@ +"""Unit tests for F7 — Rolling-window backtesting integration. + +Tests the ``backtest()`` static method on ``SkforecastForecastProvider``, +the new ``backtest_mae / backtest_rmse / backtest_folds`` fields on +``LoadConsumptionModel``, and the corresponding schema fields. 
+""" + +import pytest + +from edge_mining.adapters.domain.home_load.forecast_providers.skforecast_provider import ( + _SKFORECAST_AVAILABLE, +) +from edge_mining.domain.home_load.common import EnergyLoadForecastProviderAdapter +from edge_mining.domain.home_load.entities import LoadConsumptionModel + +pytestmark = pytest.mark.skipif(not _SKFORECAST_AVAILABLE, reason="skforecast not installed") + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def _make_series(hours: int = 300): + """Create a pandas Series of synthetic hourly power values.""" + import pandas as pd + + values = [300.0 + (i % 24) * 10 + (i % 7) * 5 for i in range(hours)] + return pd.Series(values, name="power") + + +def _fit_forecaster(y, lags: int = 24): + """Return a fitted ForecasterRecursive on *y*.""" + from skforecast.recursive import ForecasterRecursive + from sklearn.linear_model import Ridge + + forecaster = ForecasterRecursive(estimator=Ridge(), lags=lags) + forecaster.fit(y=y) + return forecaster + + +# --------------------------------------------------------------------------- +# backtest() static method tests +# --------------------------------------------------------------------------- + +class TestSkforecastBacktest: + """Tests for SkforecastForecastProvider.backtest().""" + + def test_backtest_returns_dict_with_expected_keys(self): + from edge_mining.adapters.domain.home_load.forecast_providers.skforecast_provider import ( + SkforecastForecastProvider, + ) + + y = _make_series(300) + forecaster = _fit_forecaster(y) + result = SkforecastForecastProvider.backtest( + forecaster=forecaster, + y_series=y, + steps=24, + folds=3, + ) + assert isinstance(result, dict) + assert "backtest_mae" in result + assert "backtest_rmse" in result + assert "backtest_folds" in result + + def test_backtest_mae_is_positive(self): + from 
edge_mining.adapters.domain.home_load.forecast_providers.skforecast_provider import ( + SkforecastForecastProvider, + ) + + y = _make_series(300) + forecaster = _fit_forecaster(y) + result = SkforecastForecastProvider.backtest( + forecaster=forecaster, + y_series=y, + steps=24, + folds=3, + ) + assert result["backtest_mae"] is not None + assert result["backtest_mae"] >= 0 + + def test_backtest_rmse_is_positive(self): + from edge_mining.adapters.domain.home_load.forecast_providers.skforecast_provider import ( + SkforecastForecastProvider, + ) + + y = _make_series(300) + forecaster = _fit_forecaster(y) + result = SkforecastForecastProvider.backtest( + forecaster=forecaster, + y_series=y, + steps=24, + folds=3, + ) + assert result["backtest_rmse"] is not None + assert result["backtest_rmse"] >= 0 + + def test_backtest_folds_positive(self): + from edge_mining.adapters.domain.home_load.forecast_providers.skforecast_provider import ( + SkforecastForecastProvider, + ) + + y = _make_series(300) + forecaster = _fit_forecaster(y) + result = SkforecastForecastProvider.backtest( + forecaster=forecaster, + y_series=y, + steps=24, + folds=3, + ) + assert result["backtest_folds"] > 0 + + def test_backtest_too_short_series_returns_zeros(self): + from edge_mining.adapters.domain.home_load.forecast_providers.skforecast_provider import ( + SkforecastForecastProvider, + ) + + # Very short series — not enough for even 2*steps training + y = _make_series(30) + forecaster = _fit_forecaster(y, lags=6) + result = SkforecastForecastProvider.backtest( + forecaster=forecaster, + y_series=y, + steps=24, + folds=3, + ) + assert result["backtest_mae"] is None + assert result["backtest_folds"] == 0 + + +# --------------------------------------------------------------------------- +# LoadConsumptionModel backtest fields tests +# --------------------------------------------------------------------------- + +class TestLoadConsumptionModelBacktestFields: + """Tests for backtest_mae/rmse/folds fields 
on the entity.""" + + def test_defaults(self): + model = LoadConsumptionModel() + assert model.backtest_mae is None + assert model.backtest_rmse is None + assert model.backtest_folds == 0 + + def test_set_values(self): + model = LoadConsumptionModel(backtest_mae=12.5, backtest_rmse=15.3, backtest_folds=5) + assert model.backtest_mae == 12.5 + assert model.backtest_rmse == 15.3 + assert model.backtest_folds == 5 + + def test_schema_includes_backtest_fields(self): + from edge_mining.adapters.domain.home_load.schemas import LoadConsumptionModelSchema + + model = LoadConsumptionModel( + adapter_type=EnergyLoadForecastProviderAdapter.SKFORECAST, + backtest_mae=8.2, + backtest_rmse=10.1, + backtest_folds=4, + ) + schema = LoadConsumptionModelSchema.from_model(model) + assert schema.backtest_mae == 8.2 + assert schema.backtest_rmse == 10.1 + assert schema.backtest_folds == 4 + + def test_schema_backtest_defaults(self): + from edge_mining.adapters.domain.home_load.schemas import LoadConsumptionModelSchema + + model = LoadConsumptionModel() + schema = LoadConsumptionModelSchema.from_model(model) + assert schema.backtest_mae is None + assert schema.backtest_rmse is None + assert schema.backtest_folds == 0 diff --git a/tests/unit/adapters/home_load/test_naive_persistence_forecast_provider.py b/tests/unit/adapters/home_load/test_naive_persistence_forecast_provider.py new file mode 100644 index 0000000..84a89d4 --- /dev/null +++ b/tests/unit/adapters/home_load/test_naive_persistence_forecast_provider.py @@ -0,0 +1,213 @@ +"""Unit tests for NaivePersistence forecast provider.""" + +from datetime import datetime, timedelta, timezone + +import pytest + +from edge_mining.adapters.domain.home_load.forecast_providers.naive_persistence import ( + NaivePersistenceForecastProvider, + NaivePersistenceForecastProviderFactory, +) +from edge_mining.domain.common import Timestamp, WattHours, Watts +from edge_mining.domain.home_load.common import EnergyLoadForecastProviderAdapter +from 
edge_mining.domain.home_load.exceptions import EnergyLoadForecastProviderError +from edge_mining.domain.home_load.value_objects import ( + HomeLoadEnergyInterval, + HomeLoadPowerPoint, + LoadEnergyConsumption, +) +from edge_mining.shared.adapter_configs.home_load import ( + EnergyLoadForecastProviderNaivePersistenceConfig, + EnergyLoadForecastProviderDummyConfig, +) + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def _make_history(hours: int = 48, base_power: float = 300.0) -> LoadEnergyConsumption: + """Build a synthetic hourly history going back ``hours`` hours from now. + + Power follows a simple pattern based on hour-of-day to make assertions + deterministic: ``base_power + hour_of_day * 10``. + """ + now = datetime.now(timezone.utc).replace(minute=0, second=0, microsecond=0) + intervals = [] + for i in range(hours, 0, -1): + start = Timestamp(now - timedelta(hours=i)) + end = Timestamp(start + timedelta(hours=1)) + power = Watts(base_power + start.hour * 10) + intervals.append( + HomeLoadEnergyInterval( + start=start, + end=end, + power_points=[HomeLoadPowerPoint(timestamp=start, power=power)], + energy=WattHours(float(power)), + ) + ) + return LoadEnergyConsumption(timestamp=Timestamp(now), intervals=intervals) + + +# --------------------------------------------------------------------------- +# Factory tests +# --------------------------------------------------------------------------- + +class TestNaivePersistenceForecastProviderFactory: + """Tests for the factory.""" + + def test_create_with_default_config(self): + factory = NaivePersistenceForecastProviderFactory() + provider = factory.create(config=None, logger=None, external_service=None) + assert isinstance(provider, NaivePersistenceForecastProvider) + + def test_create_with_valid_config(self): + config = EnergyLoadForecastProviderNaivePersistenceConfig(hours_ahead=12, 
delta_days=2) + factory = NaivePersistenceForecastProviderFactory() + provider = factory.create(config=config, logger=None, external_service=None) + assert isinstance(provider, NaivePersistenceForecastProvider) + assert provider._hours_ahead == 12 + assert provider._delta_days == 2 + + def test_create_with_wrong_config_type_raises(self): + config = EnergyLoadForecastProviderDummyConfig() + factory = NaivePersistenceForecastProviderFactory() + with pytest.raises(EnergyLoadForecastProviderError): + factory.create(config=config, logger=None, external_service=None) + + +# --------------------------------------------------------------------------- +# Provider tests +# --------------------------------------------------------------------------- + +class TestNaivePersistenceForecastProvider: + """Tests for the provider.""" + + def test_adapter_type(self): + provider = NaivePersistenceForecastProvider() + assert provider.forecast_provider_type == EnergyLoadForecastProviderAdapter.NAIVE_PERSISTENCE + + def test_min_required_history_hours_default(self): + provider = NaivePersistenceForecastProvider(delta_days=1) + assert provider.min_required_history_hours == 24 + + def test_min_required_history_hours_custom(self): + provider = NaivePersistenceForecastProvider(delta_days=3) + assert provider.min_required_history_hours == 72 + + def test_returns_none_for_empty_history(self): + provider = NaivePersistenceForecastProvider(hours_ahead=3) + empty = LoadEnergyConsumption(timestamp=Timestamp(datetime.now(timezone.utc)), intervals=[]) + assert provider.get_consumption_forecast(empty) is None + + def test_returns_none_for_zero_hours(self): + provider = NaivePersistenceForecastProvider(hours_ahead=0) + history = _make_history(48) + assert provider.get_consumption_forecast(history) is None + + def test_forecast_length_matches_hours_ahead(self): + provider = NaivePersistenceForecastProvider(hours_ahead=6) + history = _make_history(48) + forecast = 
provider.get_consumption_forecast(history) + assert forecast is not None + assert len(forecast.intervals) == 6 + + def test_forecast_default_24h(self): + provider = NaivePersistenceForecastProvider(hours_ahead=24) + history = _make_history(48) + forecast = provider.get_consumption_forecast(history) + assert forecast is not None + assert len(forecast.intervals) == 24 + + def test_forecast_uses_yesterday_profile(self): + """Each forecast hour should match the power from the same hour yesterday.""" + provider = NaivePersistenceForecastProvider(hours_ahead=6, delta_days=1) + history = _make_history(48, base_power=300.0) + forecast = provider.get_consumption_forecast(history) + assert forecast is not None + + for interval in forecast.intervals: + expected_power = 300.0 + interval.start.hour * 10 + assert float(interval.avg_power) == pytest.approx(expected_power, abs=1.0) + + def test_forecast_delta_days_2(self): + """With delta_days=2, power should come from 2 days ago.""" + provider = NaivePersistenceForecastProvider(hours_ahead=3, delta_days=2) + history = _make_history(72, base_power=200.0) + forecast = provider.get_consumption_forecast(history) + assert forecast is not None + assert len(forecast.intervals) == 3 + + def test_forecast_intervals_are_contiguous(self): + provider = NaivePersistenceForecastProvider(hours_ahead=4) + history = _make_history(48) + forecast = provider.get_consumption_forecast(history) + assert forecast is not None + for i in range(len(forecast.intervals) - 1): + assert forecast.intervals[i].end == forecast.intervals[i + 1].start + + def test_forecast_power_non_negative(self): + provider = NaivePersistenceForecastProvider(hours_ahead=6) + history = _make_history(48, base_power=0.0) + forecast = provider.get_consumption_forecast(history) + assert forecast is not None + for interval in forecast.intervals: + assert float(interval.avg_power) >= 0.0 + + def test_fallback_to_avg_when_reference_missing(self): + """When reference day has gaps, 
fallback to history average.""" + now = datetime.now(timezone.utc).replace(minute=0, second=0, microsecond=0) + # Only 2 hours of history — not enough for a full reference day + intervals = [] + for i in [2, 1]: + start = Timestamp(now - timedelta(hours=i)) + end = Timestamp(start + timedelta(hours=1)) + power = Watts(500.0) + intervals.append( + HomeLoadEnergyInterval( + start=start, end=end, + power_points=[HomeLoadPowerPoint(timestamp=start, power=power)], + energy=WattHours(500.0), + ) + ) + sparse_history = LoadEnergyConsumption(timestamp=Timestamp(now), intervals=intervals) + + provider = NaivePersistenceForecastProvider(hours_ahead=3, delta_days=1) + forecast = provider.get_consumption_forecast(sparse_history) + assert forecast is not None + # Should still produce 3 intervals, falling back to avg_power + assert len(forecast.intervals) == 3 + + +# --------------------------------------------------------------------------- +# Config tests +# --------------------------------------------------------------------------- + +class TestNaivePersistenceConfig: + """Tests for the config dataclass.""" + + def test_defaults(self): + config = EnergyLoadForecastProviderNaivePersistenceConfig() + assert config.hours_ahead == 24 + assert config.delta_days == 1 + + def test_custom_values(self): + config = EnergyLoadForecastProviderNaivePersistenceConfig(hours_ahead=12, delta_days=3) + assert config.hours_ahead == 12 + assert config.delta_days == 3 + + def test_is_valid(self): + config = EnergyLoadForecastProviderNaivePersistenceConfig() + assert config.is_valid(EnergyLoadForecastProviderAdapter.NAIVE_PERSISTENCE) is True + assert config.is_valid(EnergyLoadForecastProviderAdapter.DUMMY) is False + + def test_to_dict_from_dict_roundtrip(self): + config = EnergyLoadForecastProviderNaivePersistenceConfig(hours_ahead=8, delta_days=2) + d = config.to_dict() + restored = EnergyLoadForecastProviderNaivePersistenceConfig.from_dict(d) + assert restored == config + + def 
test_frozen(self): + config = EnergyLoadForecastProviderNaivePersistenceConfig() + with pytest.raises(AttributeError): + config.hours_ahead = 10 # type: ignore[misc] diff --git a/tests/unit/adapters/home_load/test_optuna_tuning.py b/tests/unit/adapters/home_load/test_optuna_tuning.py new file mode 100644 index 0000000..ec5e032 --- /dev/null +++ b/tests/unit/adapters/home_load/test_optuna_tuning.py @@ -0,0 +1,180 @@ +"""Unit tests for Optuna Bayesian tuning integration (F6). + +Tests the ``tune()`` static method on ``SkforecastForecastProvider``, +the ``_build_search_space`` helper, and the ``tuning_params`` field on +``LoadConsumptionModel``. +""" + +from datetime import datetime, timedelta, timezone + +import pytest + +from edge_mining.adapters.domain.home_load.forecast_providers.skforecast_provider import ( + _SKFORECAST_AVAILABLE, +) +from edge_mining.domain.common import Timestamp, WattHours, Watts +from edge_mining.domain.home_load.common import EnergyLoadForecastProviderAdapter +from edge_mining.domain.home_load.entities import LoadConsumptionModel +from edge_mining.domain.home_load.value_objects import ( + HomeLoadEnergyInterval, + HomeLoadPowerPoint, + LoadEnergyConsumption, +) + +pytestmark = pytest.mark.skipif(not _SKFORECAST_AVAILABLE, reason="skforecast not installed") + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def _make_series(hours: int = 300): + """Create a pandas Series of synthetic hourly power values.""" + import pandas as pd + + values = [300.0 + (i % 24) * 10 + (i % 7) * 5 for i in range(hours)] + return pd.Series(values, name="power") + + +def _make_history(hours: int = 300, base_power: float = 300.0) -> LoadEnergyConsumption: + """Build hourly LoadEnergyConsumption for training service tests.""" + now = datetime.now(timezone.utc).replace(minute=0, second=0, microsecond=0) + intervals = [] + for i in range(hours, 
0, -1): + start = Timestamp(now - timedelta(hours=i)) + end = Timestamp(start + timedelta(hours=1)) + power = Watts(base_power + start.hour * 10 + (i % 7) * 5) + intervals.append( + HomeLoadEnergyInterval( + start=start, + end=end, + power_points=[HomeLoadPowerPoint(timestamp=start, power=power)], + energy=WattHours(float(power)), + ) + ) + return LoadEnergyConsumption(timestamp=Timestamp(now), intervals=intervals) + + +# --------------------------------------------------------------------------- +# tune() static method tests +# --------------------------------------------------------------------------- + +class TestSkforecastTune: + """Tests for SkforecastForecastProvider.tune().""" + + def test_tune_returns_params_and_forecaster(self): + from edge_mining.adapters.domain.home_load.forecast_providers.skforecast_provider import ( + SkforecastForecastProvider, + ) + + y = _make_series(300) + best_params, tuned_forecaster = SkforecastForecastProvider.tune( + y_series=y, + sklearn_model_name="Ridge", + num_lags=24, + steps=24, + n_trials=3, # small for speed + ) + assert isinstance(best_params, dict) + assert tuned_forecaster is not None + + def test_tune_with_random_forest(self): + from edge_mining.adapters.domain.home_load.forecast_providers.skforecast_provider import ( + SkforecastForecastProvider, + ) + + y = _make_series(300) + best_params, tuned_forecaster = SkforecastForecastProvider.tune( + y_series=y, + sklearn_model_name="RandomForestRegressor", + num_lags=24, + steps=24, + n_trials=3, + ) + assert isinstance(best_params, dict) + # Tuned forecaster should be able to predict + preds = tuned_forecaster.predict(steps=6) + assert len(preds) == 6 + + def test_tune_with_kneighbors(self): + from edge_mining.adapters.domain.home_load.forecast_providers.skforecast_provider import ( + SkforecastForecastProvider, + ) + + y = _make_series(300) + best_params, tuned = SkforecastForecastProvider.tune( + y_series=y, + sklearn_model_name="KNeighborsRegressor", + num_lags=24, 
+ steps=24, + n_trials=3, + ) + assert isinstance(best_params, dict) + + +# --------------------------------------------------------------------------- +# _build_search_space tests +# --------------------------------------------------------------------------- + +class TestBuildSearchSpace: + """Tests for the search space builder.""" + + def test_rf_space_callable(self): + from edge_mining.adapters.domain.home_load.forecast_providers.skforecast_provider import ( + _build_search_space, + ) + + space = _build_search_space("RandomForestRegressor") + assert callable(space) + + def test_ridge_space_callable(self): + from edge_mining.adapters.domain.home_load.forecast_providers.skforecast_provider import ( + _build_search_space, + ) + + space = _build_search_space("Ridge") + assert callable(space) + + def test_unknown_model_returns_default_space(self): + from edge_mining.adapters.domain.home_load.forecast_providers.skforecast_provider import ( + _build_search_space, + ) + + space = _build_search_space("SomeUnknownModel") + assert callable(space) + + +# --------------------------------------------------------------------------- +# LoadConsumptionModel.tuning_params tests +# --------------------------------------------------------------------------- + +class TestLoadConsumptionModelTuningParams: + """Tests for the tuning_params field on the entity.""" + + def test_default_is_none(self): + model = LoadConsumptionModel() + assert model.tuning_params is None + + def test_can_set_dict(self): + params = {"n_estimators": 200, "max_depth": 10, "lags": 48} + model = LoadConsumptionModel(tuning_params=params) + assert model.tuning_params == params + assert model.tuning_params["n_estimators"] == 200 + + def test_schema_includes_tuning_params(self): + from edge_mining.adapters.domain.home_load.schemas import LoadConsumptionModelSchema + + params = {"alpha": 0.5, "lags": 24} + model = LoadConsumptionModel( + adapter_type=EnergyLoadForecastProviderAdapter.SKFORECAST, + 
tuning_params=params, + ) + schema = LoadConsumptionModelSchema.from_model(model) + assert schema.tuning_params == params + + def test_schema_tuning_params_none(self): + from edge_mining.adapters.domain.home_load.schemas import LoadConsumptionModelSchema + + model = LoadConsumptionModel() + schema = LoadConsumptionModelSchema.from_model(model) + assert schema.tuning_params is None diff --git a/tests/unit/adapters/home_load/test_skforecast_forecast_provider.py b/tests/unit/adapters/home_load/test_skforecast_forecast_provider.py new file mode 100644 index 0000000..2f594ce --- /dev/null +++ b/tests/unit/adapters/home_load/test_skforecast_forecast_provider.py @@ -0,0 +1,267 @@ +"""Unit tests for Skforecast forecast provider.""" + +from datetime import datetime, timedelta, timezone +from unittest.mock import MagicMock + +import pytest + +from edge_mining.adapters.domain.home_load.forecast_providers.skforecast_provider import ( + SkforecastForecastProvider, + SkforecastForecastProviderFactory, + _resolve_sklearn_model, + _SKFORECAST_AVAILABLE, +) +from edge_mining.domain.common import EntityId, Timestamp, WattHours, Watts +from edge_mining.domain.home_load.common import EnergyLoadForecastProviderAdapter +from edge_mining.domain.home_load.exceptions import EnergyLoadForecastProviderError +from edge_mining.domain.home_load.value_objects import ( + HomeLoadEnergyInterval, + HomeLoadPowerPoint, + LoadEnergyConsumption, +) +from edge_mining.shared.adapter_configs.home_load import ( + EnergyLoadForecastProviderDummyConfig, + EnergyLoadForecastProviderSkforecastConfig, +) + +pytestmark = pytest.mark.skipif(not _SKFORECAST_AVAILABLE, reason="skforecast not installed") + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def _make_history(hours: int = 200, base_power: float = 300.0) -> LoadEnergyConsumption: + """Build a synthetic hourly history. 
+ + Power pattern: ``base_power + hour_of_day * 10 + sin-like wobble``. + Needs to be long enough for num_lags + forecast horizon. + """ + now = datetime.now(timezone.utc).replace(minute=0, second=0, microsecond=0) + intervals = [] + for i in range(hours, 0, -1): + start = Timestamp(now - timedelta(hours=i)) + end = Timestamp(start + timedelta(hours=1)) + power = Watts(base_power + start.hour * 10 + (i % 7) * 5) + intervals.append( + HomeLoadEnergyInterval( + start=start, + end=end, + power_points=[HomeLoadPowerPoint(timestamp=start, power=power)], + energy=WattHours(float(power)), + ) + ) + return LoadEnergyConsumption(timestamp=Timestamp(now), intervals=intervals) + + +# --------------------------------------------------------------------------- +# sklearn resolver tests +# --------------------------------------------------------------------------- + +class TestSklearnModelResolver: + """Tests for _resolve_sklearn_model.""" + + def test_random_forest(self): + model = _resolve_sklearn_model("RandomForestRegressor") + assert model.__class__.__name__ == "RandomForestRegressor" + + def test_ridge(self): + model = _resolve_sklearn_model("Ridge") + assert model.__class__.__name__ == "Ridge" + + def test_kneighbors(self): + model = _resolve_sklearn_model("KNeighborsRegressor") + assert model.__class__.__name__ == "KNeighborsRegressor" + + def test_gradient_boosting(self): + model = _resolve_sklearn_model("GradientBoostingRegressor") + assert model.__class__.__name__ == "GradientBoostingRegressor" + + def test_unsupported_model_raises(self): + with pytest.raises(EnergyLoadForecastProviderError, match="Unsupported sklearn model"): + _resolve_sklearn_model("FakeModel") + + +# --------------------------------------------------------------------------- +# Factory tests +# --------------------------------------------------------------------------- + +class TestSkforecastForecastProviderFactory: + """Tests for the factory.""" + + def test_create_with_default_config(self): + 
factory = SkforecastForecastProviderFactory() + provider = factory.create(config=None, logger=None, external_service=None) + assert isinstance(provider, SkforecastForecastProvider) + + def test_create_with_valid_config(self): + config = EnergyLoadForecastProviderSkforecastConfig( + hours_ahead=12, weeks_lookback=4, sklearn_model="Ridge", num_lags=48 + ) + factory = SkforecastForecastProviderFactory() + provider = factory.create(config=config, logger=None, external_service=None) + assert isinstance(provider, SkforecastForecastProvider) + assert provider._hours_ahead == 12 + assert provider._sklearn_model == "Ridge" + assert provider._num_lags == 48 + + def test_create_with_model_repo(self): + mock_repo = MagicMock() + factory = SkforecastForecastProviderFactory(model_repo=mock_repo) + provider = factory.create(config=None, logger=None, external_service=None) + assert provider._model_repo is mock_repo + + def test_create_with_wrong_config_type_raises(self): + config = EnergyLoadForecastProviderDummyConfig() + factory = SkforecastForecastProviderFactory() + with pytest.raises(EnergyLoadForecastProviderError): + factory.create(config=config, logger=None, external_service=None) + + +# --------------------------------------------------------------------------- +# Provider tests +# --------------------------------------------------------------------------- + +class TestSkforecastForecastProvider: + """Tests for the provider.""" + + def test_adapter_type(self): + provider = SkforecastForecastProvider() + assert provider.forecast_provider_type == EnergyLoadForecastProviderAdapter.SKFORECAST + + def test_min_required_history_hours(self): + provider = SkforecastForecastProvider(num_lags=72, hours_ahead=24) + assert provider.min_required_history_hours == 72 + 48 + 24 + + def test_returns_none_for_empty_history(self): + provider = SkforecastForecastProvider(hours_ahead=6, num_lags=24) + empty = LoadEnergyConsumption(timestamp=Timestamp(datetime.now(timezone.utc)), intervals=[]) 
+        assert provider.get_consumption_forecast(empty) is None
+
+    def test_returns_none_for_zero_hours(self):
+        provider = SkforecastForecastProvider(hours_ahead=0)
+        history = _make_history(200)
+        assert provider.get_consumption_forecast(history) is None
+
+    def test_returns_none_for_insufficient_history(self):
+        provider = SkforecastForecastProvider(hours_ahead=24, num_lags=72)
+        short_history = _make_history(50)  # < 72 + 48 + 24 = 144 hours required
+        assert provider.get_consumption_forecast(short_history) is None
+
+    def test_forecast_length_matches_hours_ahead(self):
+        provider = SkforecastForecastProvider(hours_ahead=6, num_lags=24)
+        history = _make_history(200)
+        forecast = provider.get_consumption_forecast(history)
+        assert forecast is not None
+        assert len(forecast.intervals) == 6
+
+    def test_forecast_with_random_forest(self):
+        provider = SkforecastForecastProvider(
+            hours_ahead=12, num_lags=24, sklearn_model="RandomForestRegressor"
+        )
+        history = _make_history(200)
+        forecast = provider.get_consumption_forecast(history)
+        assert forecast is not None
+        assert len(forecast.intervals) == 12
+
+    def test_forecast_with_ridge(self):
+        provider = SkforecastForecastProvider(
+            hours_ahead=6, num_lags=24, sklearn_model="Ridge"
+        )
+        history = _make_history(200)
+        forecast = provider.get_consumption_forecast(history)
+        assert forecast is not None
+        assert len(forecast.intervals) == 6
+
+    def test_forecast_intervals_are_contiguous(self):
+        provider = SkforecastForecastProvider(hours_ahead=8, num_lags=24)
+        history = _make_history(200)
+        forecast = provider.get_consumption_forecast(history)
+        assert forecast is not None
+        for i in range(len(forecast.intervals) - 1):
+            assert forecast.intervals[i].end == forecast.intervals[i + 1].start
+
+    def test_forecast_power_non_negative(self):
+        provider = SkforecastForecastProvider(hours_ahead=6, num_lags=24)
+        history = _make_history(200)
+        forecast = provider.get_consumption_forecast(history)
+        assert forecast is not None
+        for interval
in forecast.intervals: + assert float(interval.avg_power) >= 0.0 + + def test_saved_model_used_when_available(self): + """If model_repo returns a saved model, it should be used.""" + import pickle + + import pandas as pd + from skforecast.recursive import ForecasterRecursive + from sklearn.linear_model import Ridge + + # Train a small model + history = _make_history(200) + from edge_mining.adapters.domain.home_load.forecast_providers.features import ( + fill_missing_hours, + intervals_to_hourly_series, + ) + + series = intervals_to_hourly_series(history) + series = fill_missing_hours(series) + powers = [p for _, p in series] + y = pd.Series(powers, name="power") + forecaster = ForecasterRecursive(estimator=Ridge(), lags=24) + forecaster.fit(y=y) + model_bytes = pickle.dumps(forecaster) + + # Mock model_repo + mock_model = MagicMock() + mock_model.model_bytes = model_bytes + mock_repo = MagicMock() + mock_repo.get_active_model.return_value = mock_model + + provider = SkforecastForecastProvider( + hours_ahead=6, num_lags=24, model_repo=mock_repo + ) + forecast = provider.get_consumption_forecast(history) + assert forecast is not None + assert len(forecast.intervals) == 6 + mock_repo.get_active_model.assert_called_once() + + +# --------------------------------------------------------------------------- +# Config tests +# --------------------------------------------------------------------------- + +class TestSkforecastConfig: + """Tests for the config dataclass.""" + + def test_defaults(self): + config = EnergyLoadForecastProviderSkforecastConfig() + assert config.hours_ahead == 24 + assert config.weeks_lookback == 8 + assert config.sklearn_model == "RandomForestRegressor" + assert config.num_lags == 72 + + def test_custom_values(self): + config = EnergyLoadForecastProviderSkforecastConfig( + hours_ahead=12, weeks_lookback=4, sklearn_model="Ridge", num_lags=48 + ) + assert config.sklearn_model == "Ridge" + assert config.num_lags == 48 + + def test_is_valid(self): + 
config = EnergyLoadForecastProviderSkforecastConfig() + assert config.is_valid(EnergyLoadForecastProviderAdapter.SKFORECAST) is True + assert config.is_valid(EnergyLoadForecastProviderAdapter.DUMMY) is False + + def test_to_dict_from_dict_roundtrip(self): + config = EnergyLoadForecastProviderSkforecastConfig( + hours_ahead=6, sklearn_model="Lasso", num_lags=36 + ) + d = config.to_dict() + restored = EnergyLoadForecastProviderSkforecastConfig.from_dict(d) + assert restored == config + + def test_frozen(self): + config = EnergyLoadForecastProviderSkforecastConfig() + with pytest.raises(AttributeError): + config.hours_ahead = 10 # type: ignore[misc] diff --git a/tests/unit/adapters/home_load/test_typical_profile_forecast_provider.py b/tests/unit/adapters/home_load/test_typical_profile_forecast_provider.py new file mode 100644 index 0000000..122b218 --- /dev/null +++ b/tests/unit/adapters/home_load/test_typical_profile_forecast_provider.py @@ -0,0 +1,274 @@ +"""Unit tests for TypicalProfile forecast provider.""" + +from datetime import datetime, timedelta, timezone +from unittest.mock import patch + +import pytest + +from edge_mining.adapters.domain.home_load.forecast_providers.typical_profile import ( + TypicalProfileForecastProvider, + TypicalProfileForecastProviderFactory, +) +from edge_mining.domain.common import Timestamp, WattHours, Watts +from edge_mining.domain.home_load.common import EnergyLoadForecastProviderAdapter +from edge_mining.domain.home_load.exceptions import EnergyLoadForecastProviderError +from edge_mining.domain.home_load.value_objects import ( + HomeLoadEnergyInterval, + HomeLoadPowerPoint, + LoadEnergyConsumption, +) +from edge_mining.shared.adapter_configs.home_load import ( + EnergyLoadForecastProviderDummyConfig, + EnergyLoadForecastProviderTypicalProfileConfig, +) + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def 
_make_history(weeks: int = 4, base_power: float = 300.0, ref_time: datetime | None = None) -> LoadEnergyConsumption:
+    """Build a synthetic hourly history going back ``weeks`` weeks.
+
+    Power follows a deterministic pattern:
+    ``base_power + (month * 5) + (dow * 3) + (hour * 10)``
+    so each (month, dow, hour) has a unique, predictable value.
+    """
+    now = ref_time if ref_time else datetime.now(timezone.utc).replace(minute=0, second=0, microsecond=0)
+    total_hours = weeks * 168
+    intervals = []
+    for i in range(total_hours, 0, -1):
+        start = Timestamp(now - timedelta(hours=i))
+        end = Timestamp(start + timedelta(hours=1))
+        power = Watts(base_power + start.month * 5 + start.weekday() * 3 + start.hour * 10)
+        intervals.append(
+            HomeLoadEnergyInterval(
+                start=start,
+                end=end,
+                power_points=[HomeLoadPowerPoint(timestamp=start, power=power)],
+                energy=WattHours(float(power)),
+            )
+        )
+    return LoadEnergyConsumption(timestamp=Timestamp(now), intervals=intervals)
+
+
+def _make_sparse_history(hours: int = 48, base_power: float = 400.0) -> LoadEnergyConsumption:
+    """History with only a few hours, some (month, dow, hour) combos missing."""
+    now = datetime.now(timezone.utc).replace(minute=0, second=0, microsecond=0)
+    intervals = []
+    for i in range(hours, 0, -1):
+        start = Timestamp(now - timedelta(hours=i))
+        end = Timestamp(start + timedelta(hours=1))
+        power = Watts(base_power + start.hour * 10)
+        intervals.append(
+            HomeLoadEnergyInterval(
+                start=start,
+                end=end,
+                power_points=[HomeLoadPowerPoint(timestamp=start, power=power)],
+                energy=WattHours(float(power)),
+            )
+        )
+    return LoadEnergyConsumption(timestamp=Timestamp(now), intervals=intervals)
+
+
+# ---------------------------------------------------------------------------
+# Factory tests
+# ---------------------------------------------------------------------------
+
+class TestTypicalProfileForecastProviderFactory:
+    """Tests for the factory."""
+
+    def test_create_with_default_config(self):
+        factory =
TypicalProfileForecastProviderFactory() + provider = factory.create(config=None, logger=None, external_service=None) + assert isinstance(provider, TypicalProfileForecastProvider) + + def test_create_with_valid_config(self): + config = EnergyLoadForecastProviderTypicalProfileConfig(hours_ahead=12, weeks_lookback=4) + factory = TypicalProfileForecastProviderFactory() + provider = factory.create(config=config, logger=None, external_service=None) + assert isinstance(provider, TypicalProfileForecastProvider) + assert provider._hours_ahead == 12 + assert provider._weeks_lookback == 4 + + def test_create_with_wrong_config_type_raises(self): + config = EnergyLoadForecastProviderDummyConfig() + factory = TypicalProfileForecastProviderFactory() + with pytest.raises(EnergyLoadForecastProviderError): + factory.create(config=config, logger=None, external_service=None) + + +# --------------------------------------------------------------------------- +# Provider tests +# --------------------------------------------------------------------------- + +class TestTypicalProfileForecastProvider: + """Tests for the provider.""" + + def test_adapter_type(self): + provider = TypicalProfileForecastProvider() + assert provider.forecast_provider_type == EnergyLoadForecastProviderAdapter.TYPICAL_PROFILE + + def test_min_required_history_hours_default(self): + provider = TypicalProfileForecastProvider(weeks_lookback=8) + assert provider.min_required_history_hours == 8 * 168 + + def test_min_required_history_hours_custom(self): + provider = TypicalProfileForecastProvider(weeks_lookback=2) + assert provider.min_required_history_hours == 2 * 168 + + def test_returns_none_for_empty_history(self): + provider = TypicalProfileForecastProvider(hours_ahead=6) + empty = LoadEnergyConsumption(timestamp=Timestamp(datetime.now(timezone.utc)), intervals=[]) + assert provider.get_consumption_forecast(empty) is None + + def test_returns_none_for_zero_hours(self): + provider = 
TypicalProfileForecastProvider(hours_ahead=0) + history = _make_history(4) + assert provider.get_consumption_forecast(history) is None + + def test_forecast_length_matches_hours_ahead(self): + provider = TypicalProfileForecastProvider(hours_ahead=6) + history = _make_history(4) + forecast = provider.get_consumption_forecast(history) + assert forecast is not None + assert len(forecast.intervals) == 6 + + def test_forecast_default_24h(self): + provider = TypicalProfileForecastProvider(hours_ahead=24) + history = _make_history(4) + forecast = provider.get_consumption_forecast(history) + assert forecast is not None + assert len(forecast.intervals) == 24 + + def test_forecast_uses_month_dow_hour_profile(self): + """Power should match the (month, dow, hour) average from history.""" + # Use a fixed reference time deep in a month so 4 weeks back stays in the same month. + fixed_now = datetime(2026, 7, 29, 12, 0, 0, tzinfo=timezone.utc) + provider = TypicalProfileForecastProvider(hours_ahead=6) + history = _make_history(4, base_power=300.0, ref_time=fixed_now) + with patch( + "edge_mining.adapters.domain.home_load.forecast_providers.typical_profile.datetime", + wraps=datetime, + ) as mock_dt: + mock_dt.now.return_value = fixed_now + forecast = provider.get_consumption_forecast(history) + assert forecast is not None + + for interval in forecast.intervals: + expected = 300.0 + interval.start.month * 5 + interval.start.weekday() * 3 + interval.start.hour * 10 + assert float(interval.avg_power) == pytest.approx(expected, abs=1.0) + + def test_forecast_intervals_are_contiguous(self): + provider = TypicalProfileForecastProvider(hours_ahead=8) + history = _make_history(4) + forecast = provider.get_consumption_forecast(history) + assert forecast is not None + for i in range(len(forecast.intervals) - 1): + assert forecast.intervals[i].end == forecast.intervals[i + 1].start + + def test_forecast_power_non_negative(self): + provider = TypicalProfileForecastProvider(hours_ahead=6) + 
history = _make_history(4, base_power=0.0)
+        forecast = provider.get_consumption_forecast(history)
+        assert forecast is not None
+        for interval in forecast.intervals:
+            assert float(interval.avg_power) >= 0.0
+
+    def test_fallback_to_dow_hour_when_month_missing(self):
+        """When exact (month, dow, hour) isn't available, fall back to (dow, hour)."""
+        now = datetime.now(timezone.utc).replace(minute=0, second=0, microsecond=0)
+        # Build one week of history entirely in the previous month, so the exact
+        # (month, dow, hour) slot is missing and the (dow, hour) fallback is used.
+        intervals = []
+        if now.month == 1:
+            different_month_start = now.replace(year=now.year - 1, month=12, day=1)
+        else:
+            different_month_start = now.replace(month=now.month - 1, day=1)
+
+        for i in range(168):  # 1 week in the different month
+            start = Timestamp(different_month_start + timedelta(hours=i))
+            end = Timestamp(start + timedelta(hours=1))
+            power = Watts(500.0)
+            intervals.append(
+                HomeLoadEnergyInterval(
+                    start=start, end=end,
+                    power_points=[HomeLoadPowerPoint(timestamp=start, power=power)],
+                    energy=WattHours(500.0),
+                )
+            )
+        history = LoadEnergyConsumption(timestamp=Timestamp(now), intervals=intervals)
+
+        provider = TypicalProfileForecastProvider(hours_ahead=3)
+        forecast = provider.get_consumption_forecast(history)
+        assert forecast is not None
+        assert len(forecast.intervals) == 3
+        # All should be 500.0 from (dow, hour) fallback
+        for interval in forecast.intervals:
+            assert float(interval.avg_power) == pytest.approx(500.0, abs=1.0)
+
+    def test_fallback_to_global_avg_when_all_missing(self):
+        """When no matching (dow, hour) slots exist, uses global average."""
+        now = datetime.now(timezone.utc).replace(minute=0, second=0, microsecond=0)
+        # Create only 2 intervals at very specific times
+        intervals = []
+        for i in [2, 1]:
+            start = Timestamp(now - timedelta(hours=i))
+            end = Timestamp(start + timedelta(hours=1))
+            power =
Watts(750.0) + intervals.append( + HomeLoadEnergyInterval( + start=start, end=end, + power_points=[HomeLoadPowerPoint(timestamp=start, power=power)], + energy=WattHours(750.0), + ) + ) + sparse = LoadEnergyConsumption(timestamp=Timestamp(now), intervals=intervals) + + provider = TypicalProfileForecastProvider(hours_ahead=24) + forecast = provider.get_consumption_forecast(sparse) + assert forecast is not None + assert len(forecast.intervals) == 24 + + def test_sparse_history_still_produces_forecast(self): + """With limited history, provider should still produce valid intervals.""" + provider = TypicalProfileForecastProvider(hours_ahead=6) + history = _make_sparse_history(48) + forecast = provider.get_consumption_forecast(history) + assert forecast is not None + assert len(forecast.intervals) == 6 + + +# --------------------------------------------------------------------------- +# Config tests +# --------------------------------------------------------------------------- + +class TestTypicalProfileConfig: + """Tests for the config dataclass.""" + + def test_defaults(self): + config = EnergyLoadForecastProviderTypicalProfileConfig() + assert config.hours_ahead == 24 + assert config.weeks_lookback == 8 + + def test_custom_values(self): + config = EnergyLoadForecastProviderTypicalProfileConfig(hours_ahead=12, weeks_lookback=4) + assert config.hours_ahead == 12 + assert config.weeks_lookback == 4 + + def test_is_valid(self): + config = EnergyLoadForecastProviderTypicalProfileConfig() + assert config.is_valid(EnergyLoadForecastProviderAdapter.TYPICAL_PROFILE) is True + assert config.is_valid(EnergyLoadForecastProviderAdapter.DUMMY) is False + + def test_to_dict_from_dict_roundtrip(self): + config = EnergyLoadForecastProviderTypicalProfileConfig(hours_ahead=6, weeks_lookback=2) + d = config.to_dict() + restored = EnergyLoadForecastProviderTypicalProfileConfig.from_dict(d) + assert restored == config + + def test_frozen(self): + config = 
EnergyLoadForecastProviderTypicalProfileConfig() + with pytest.raises(AttributeError): + config.hours_ahead = 10 # type: ignore[misc] diff --git a/tests/unit/adapters/infrastructure/rule_engine/test_rule_evaluator.py b/tests/unit/adapters/infrastructure/rule_engine/test_rule_evaluator.py index 8ce0173..a008f0a 100644 --- a/tests/unit/adapters/infrastructure/rule_engine/test_rule_evaluator.py +++ b/tests/unit/adapters/infrastructure/rule_engine/test_rule_evaluator.py @@ -544,9 +544,70 @@ def test_weekend_condition_evaluation(self): result = RuleEvaluator.evaluate_rule_conditions(self.mock_context, conditions_dict) self.assertTrue(result) + # === Tests for dict key lookup in _get_field_value === + + def test_get_field_value_dict_key_lookup(self): + """Test field resolver traverses dict keys (e.g. home_load.devices.boiler).""" + self.mock_context.home_load = Mock() + self.mock_context.home_load.devices = { + "boiler": Mock(forecast=Mock(total_energy=1500.0)), + } + + result = RuleEvaluator._get_field_value( + self.mock_context, "home_load.devices.boiler.forecast.total_energy" + ) + self.assertEqual(result, 1500.0) + + def test_get_field_value_dict_key_missing(self): + """Test field resolver returns None for missing dict key.""" + self.mock_context.home_load = Mock() + self.mock_context.home_load.devices = {"boiler": Mock()} + + result = RuleEvaluator._get_field_value( + self.mock_context, "home_load.devices.fridge.forecast.total_energy" + ) + self.assertIsNone(result) + + def test_get_field_value_none_intermediate(self): + """Test field resolver returns None when intermediate value is None.""" + self.mock_context.home_load = None + + result = RuleEvaluator._get_field_value( + self.mock_context, "home_load.devices.boiler" + ) + self.assertIsNone(result) + + def test_evaluate_condition_with_dict_path(self): + """Test end-to-end evaluation through dict key path.""" + self.mock_context.home_load = Mock() + self.mock_context.home_load.total_forecast = Mock() + 
self.mock_context.home_load.total_forecast.avg_power = 2800.0 + + conditions_dict = { + "field": "home_load.total_forecast.avg_power", + "operator": "gt", + "value": 2500, + } + + result = RuleEvaluator.evaluate_rule_conditions(self.mock_context, conditions_dict) + self.assertTrue(result) + + def test_evaluate_condition_device_dict_path(self): + """Test evaluation using device name in dict path.""" + self.mock_context.home_load = Mock() + self.mock_context.home_load.devices = { + "boiler": Mock(forecast=Mock(next_1h=Mock(peak_power=2200.0))), + } + + conditions_dict = { + "field": "home_load.devices.boiler.forecast.next_1h.peak_power", + "operator": "gt", + "value": 2000, + } + + result = RuleEvaluator.evaluate_rule_conditions(self.mock_context, conditions_dict) + self.assertTrue(result) + if __name__ == "__main__": unittest.main() - unittest.main() - unittest.main() - unittest.main() diff --git a/tests/unit/application/events/test_policy_events.py b/tests/unit/application/events/test_policy_events.py index d72ca5a..83e9517 100644 --- a/tests/unit/application/events/test_policy_events.py +++ b/tests/unit/application/events/test_policy_events.py @@ -23,7 +23,7 @@ def test_creation_with_properties(self): energy_source=None, energy_state=None, forecast=None, - home_load_forecast=None, + home_load=None, mining_performance=None, ) event = DecisionalContextUpdatedEvent( diff --git a/tests/unit/application/services/test_configuration_event_flow.py b/tests/unit/application/services/test_configuration_event_flow.py index c1c367c..f23fc43 100644 --- a/tests/unit/application/services/test_configuration_event_flow.py +++ b/tests/unit/application/services/test_configuration_event_flow.py @@ -50,7 +50,9 @@ def mock_persistence(): "policy_repo", "optimization_unit_repo", "forecast_provider_repo", - "home_forecast_provider_repo", + "energy_load_forecast_provider_repo", + "energy_load_history_provider_repo", + "home_profile_repo", "mining_performance_tracker_repo", "notifier_repo", 
"settings_repo", @@ -189,7 +191,9 @@ async def test_end_to_end_cache_invalidation(mock_persistence, logger): miner_repo=mock_persistence.miner_repo, notifier_repo=mock_persistence.notifier_repo, forecast_provider_repo=mock_persistence.forecast_provider_repo, - home_forecast_provider_repo=mock_persistence.home_forecast_provider_repo, + energy_load_forecast_provider_repo=mock_persistence.energy_load_forecast_provider_repo, + energy_load_history_provider_repo=mock_persistence.energy_load_history_provider_repo, + home_load_history_repo=MagicMock(), mining_performance_tracker_repo=mock_persistence.mining_performance_tracker_repo, external_service_repo=mock_persistence.external_service_repo, event_bus=event_bus, @@ -237,7 +241,9 @@ async def test_external_service_update_clears_all_instance_cache(mock_persistenc miner_repo=mock_persistence.miner_repo, notifier_repo=mock_persistence.notifier_repo, forecast_provider_repo=mock_persistence.forecast_provider_repo, - home_forecast_provider_repo=mock_persistence.home_forecast_provider_repo, + energy_load_forecast_provider_repo=mock_persistence.energy_load_forecast_provider_repo, + energy_load_history_provider_repo=mock_persistence.energy_load_history_provider_repo, + home_load_history_repo=MagicMock(), mining_performance_tracker_repo=mock_persistence.mining_performance_tracker_repo, external_service_repo=mock_persistence.external_service_repo, event_bus=event_bus, diff --git a/tests/unit/application/services/test_configuration_service_performance.py b/tests/unit/application/services/test_configuration_service_performance.py index f4d55cf..143aa35 100644 --- a/tests/unit/application/services/test_configuration_service_performance.py +++ b/tests/unit/application/services/test_configuration_service_performance.py @@ -60,7 +60,9 @@ def mock_persistence(): "policy_repo", "optimization_unit_repo", "forecast_provider_repo", - "home_forecast_provider_repo", + "energy_load_forecast_provider_repo", + "energy_load_history_provider_repo", + 
"home_profile_repo", "mining_performance_tracker_repo", "notifier_repo", "settings_repo", diff --git a/tests/unit/application/services/test_home_load_history_service.py b/tests/unit/application/services/test_home_load_history_service.py new file mode 100644 index 0000000..535bf3f --- /dev/null +++ b/tests/unit/application/services/test_home_load_history_service.py @@ -0,0 +1,79 @@ +"""Unit tests for HomeLoadHistoryService.get_device_history.""" + +import uuid +from datetime import datetime, timedelta +from unittest.mock import MagicMock + +import pytest + +from edge_mining.application.services.home_load_history_service import HomeLoadHistoryService +from edge_mining.domain.common import EntityId, Timestamp, Watts +from edge_mining.domain.home_load.value_objects import HomeLoadPowerPoint + + +@pytest.fixture +def mock_home_loads_repo(): + return MagicMock() + + +@pytest.fixture +def mock_history_repo(): + return MagicMock() + + +@pytest.fixture +def mock_adapter_service(): + return MagicMock() + + +@pytest.fixture +def logger(): + mock = MagicMock() + mock.debug = MagicMock() + mock.info = MagicMock() + mock.warning = MagicMock() + mock.error = MagicMock() + return mock + + +@pytest.fixture +def service(mock_home_loads_repo, mock_history_repo, mock_adapter_service, logger): + return HomeLoadHistoryService( + home_loads_repo=mock_home_loads_repo, + home_load_history_repo=mock_history_repo, + adapter_service=mock_adapter_service, + event_bus=None, + logger=logger, + ) + + +class TestGetDeviceHistory: + def test_returns_power_points_from_repo(self, service, mock_history_repo): + device_id = EntityId(uuid.uuid4()) + now = datetime.now() + start = Timestamp(now - timedelta(hours=24)) + end = Timestamp(now) + + expected_points = [ + HomeLoadPowerPoint(timestamp=Timestamp(now - timedelta(hours=2)), power=Watts(100.0)), + HomeLoadPowerPoint(timestamp=Timestamp(now - timedelta(hours=1)), power=Watts(150.0)), + ] + mock_history_repo.get_power_points.return_value = 
expected_points + + result = service.get_device_history(device_id, start, end) + + assert result == expected_points + mock_history_repo.get_power_points.assert_called_once_with(device_id, start, end) + + def test_returns_empty_list_when_no_data(self, service, mock_history_repo): + device_id = EntityId(uuid.uuid4()) + now = datetime.now() + start = Timestamp(now - timedelta(hours=24)) + end = Timestamp(now) + + mock_history_repo.get_power_points.return_value = [] + + result = service.get_device_history(device_id, start, end) + + assert result == [] + mock_history_repo.get_power_points.assert_called_once_with(device_id, start, end) diff --git a/tests/unit/application/services/test_load_forecast_training_service.py b/tests/unit/application/services/test_load_forecast_training_service.py new file mode 100644 index 0000000..e2ea813 --- /dev/null +++ b/tests/unit/application/services/test_load_forecast_training_service.py @@ -0,0 +1,123 @@ +"""Unit tests for LoadForecastModelTrainingService.train_device and get_models.""" + +import uuid +from datetime import datetime +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from edge_mining.application.services.load_forecast_training_service import LoadForecastModelTrainingService +from edge_mining.domain.common import EntityId +from edge_mining.domain.home_load.aggregate_roots import HomeLoadsProfile +from edge_mining.domain.home_load.common import EnergyLoadForecastProviderAdapter +from edge_mining.domain.home_load.entities import LoadConsumptionModel, LoadDevice + + +@pytest.fixture +def mock_home_loads_repo(): + return MagicMock() + + +@pytest.fixture +def mock_history_repo(): + return MagicMock() + + +@pytest.fixture +def mock_model_repo(): + return MagicMock() + + +@pytest.fixture +def logger(): + mock = MagicMock() + mock.debug = MagicMock() + mock.info = MagicMock() + mock.warning = MagicMock() + mock.error = MagicMock() + return mock + + +@pytest.fixture +def service(mock_home_loads_repo, 
mock_history_repo, mock_model_repo, logger): + return LoadForecastModelTrainingService( + home_loads_repo=mock_home_loads_repo, + history_repo=mock_history_repo, + model_repo=mock_model_repo, + logger=logger, + ) + + +@pytest.fixture +def device_id() -> EntityId: + return EntityId(uuid.uuid4()) + + +@pytest.fixture +def profile_with_device(device_id): + device = LoadDevice(id=device_id, name="Dishwasher", enabled=True) + profile = HomeLoadsProfile(name="Test Home", devices=[device]) + return profile + + +class TestTrainDevice: + @pytest.mark.asyncio + async def test_train_device_calls_train_for_device(self, service, mock_home_loads_repo, device_id, profile_with_device): + mock_home_loads_repo.get_all.return_value = [profile_with_device] + + with patch.object(service, "_train_for_device", new_callable=AsyncMock) as mock_train: + await service.train_device(device_id) + mock_train.assert_awaited_once_with(device_id, "Dishwasher", 8) + + @pytest.mark.asyncio + async def test_train_device_with_custom_lookback(self, service, mock_home_loads_repo, device_id, profile_with_device): + mock_home_loads_repo.get_all.return_value = [profile_with_device] + + with patch.object(service, "_train_for_device", new_callable=AsyncMock) as mock_train: + await service.train_device(device_id, weeks_lookback=4) + mock_train.assert_awaited_once_with(device_id, "Dishwasher", 4) + + @pytest.mark.asyncio + async def test_train_device_unknown_device_skips(self, service, mock_home_loads_repo, logger): + mock_home_loads_repo.get_all.return_value = [] + unknown_id = EntityId(uuid.uuid4()) + + with patch.object(service, "_train_for_device", new_callable=AsyncMock) as mock_train: + await service.train_device(unknown_id) + mock_train.assert_not_awaited() + logger.warning.assert_called_once() + + @pytest.mark.asyncio + async def test_train_device_finds_device_across_profiles(self, service, mock_home_loads_repo, device_id): + device = LoadDevice(id=device_id, name="Target", enabled=True) + profile1 = 
HomeLoadsProfile(name="Home 1", devices=[]) + profile2 = HomeLoadsProfile(name="Home 2", devices=[device]) + mock_home_loads_repo.get_all.return_value = [profile1, profile2] + + with patch.object(service, "_train_for_device", new_callable=AsyncMock) as mock_train: + await service.train_device(device_id) + mock_train.assert_awaited_once_with(device_id, "Target", 8) + + +class TestGetModels: + def test_get_models_delegates_to_repo(self, service, mock_model_repo): + expected = [ + LoadConsumptionModel( + adapter_type=EnergyLoadForecastProviderAdapter.STATSMODELS, + mae=1.0, + is_active=True, + ) + ] + mock_model_repo.get_all.return_value = expected + + result = service.get_models() + + assert result == expected + mock_model_repo.get_all.assert_called_once_with(None) + + def test_get_models_with_device_filter(self, service, mock_model_repo, device_id): + mock_model_repo.get_all.return_value = [] + + service.get_models(device_id=device_id) + + mock_model_repo.get_all.assert_called_once_with(device_id) diff --git a/tests/unit/domain/home_load/__init__.py b/tests/unit/domain/home_load/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/unit/domain/home_load/test_load_consumption_model_repository.py b/tests/unit/domain/home_load/test_load_consumption_model_repository.py new file mode 100644 index 0000000..9b4e6d1 --- /dev/null +++ b/tests/unit/domain/home_load/test_load_consumption_model_repository.py @@ -0,0 +1,85 @@ +"""Unit tests for InMemoryLoadConsumptionModelRepository.get_all.""" + +import uuid + +import pytest + +from edge_mining.adapters.domain.home_load.repositories import InMemoryLoadConsumptionModelRepository +from edge_mining.domain.common import EntityId +from edge_mining.domain.home_load.common import EnergyLoadForecastProviderAdapter +from edge_mining.domain.home_load.entities import LoadConsumptionModel + + +@pytest.fixture +def repo() -> InMemoryLoadConsumptionModelRepository: + return InMemoryLoadConsumptionModelRepository() + + 
+@pytest.fixture
+def device_id_a() -> EntityId:
+    return EntityId(uuid.uuid4())
+
+
+@pytest.fixture
+def device_id_b() -> EntityId:
+    return EntityId(uuid.uuid4())
+
+
+def _make_model(
+    device_id: EntityId,
+    adapter: EnergyLoadForecastProviderAdapter = EnergyLoadForecastProviderAdapter.STATSMODELS,
+    is_active: bool = False,
+) -> LoadConsumptionModel:
+    return LoadConsumptionModel(
+        device_id=device_id,
+        adapter_type=adapter,
+        is_active=is_active,
+        mae=1.5,
+        samples_used=100,
+    )
+
+
+class TestInMemoryLoadConsumptionModelGetAll:
+    def test_get_all_empty(self, repo):
+        assert repo.get_all() == []
+
+    def test_get_all_returns_all_models(self, repo, device_id_a, device_id_b):
+        m1 = _make_model(device_id_a)
+        m2 = _make_model(device_id_b, adapter=EnergyLoadForecastProviderAdapter.XGBOOST)
+        repo.add(m1)
+        repo.add(m2)
+
+        result = repo.get_all()
+        assert len(result) == 2
+        result_ids = {str(m.id) for m in result}
+        assert str(m1.id) in result_ids
+        assert str(m2.id) in result_ids
+
+    def test_get_all_filtered_by_device_id(self, repo, device_id_a, device_id_b):
+        m1 = _make_model(device_id_a)
+        m2 = _make_model(device_id_b)
+        m3 = _make_model(device_id_a, adapter=EnergyLoadForecastProviderAdapter.XGBOOST)
+        repo.add(m1)
+        repo.add(m2)
+        repo.add(m3)
+
+        result = repo.get_all(device_id=device_id_a)
+        assert len(result) == 2
+        for m in result:
+            assert str(m.device_id) == str(device_id_a)
+
+    def test_get_all_filtered_returns_empty_for_unknown_device(self, repo, device_id_a):
+        repo.add(_make_model(device_id_a))
+        unknown = EntityId(uuid.uuid4())
+        assert repo.get_all(device_id=unknown) == []
+
+    def test_get_all_returns_deep_copies(self, repo, device_id_a):
+        m1 = _make_model(device_id_a)
+        repo.add(m1)
+
+        result = repo.get_all()
+        assert len(result) == 1
+        result[0].mae = 999.0
+        # Original should be unchanged
+        original = repo.get_by_id(m1.id)
+        assert original.mae == 1.5
diff --git a/tests/unit/domain/home_load/test_load_energy_consumption_mix.py b/tests/unit/domain/home_load/test_load_energy_consumption_mix.py
new file mode 100644
index 0000000..661140a
--- /dev/null
+++ b/tests/unit/domain/home_load/test_load_energy_consumption_mix.py
@@ -0,0 +1,120 @@
+"""Unit tests for LoadEnergyConsumption.mix() — α/β forecast blending."""
+
+from datetime import datetime, timedelta, timezone
+
+import pytest
+
+from edge_mining.domain.common import Timestamp, WattHours, Watts
+from edge_mining.domain.home_load.value_objects import (
+    HomeLoadEnergyInterval,
+    HomeLoadPowerPoint,
+    LoadEnergyConsumption,
+)
+
+
+def _ts(offset_hours: int = 0) -> Timestamp:
+    """Utility: create a UTC timestamp offset from a fixed base."""
+    base = datetime(2026, 4, 25, 10, 0, 0, tzinfo=timezone.utc)
+    return Timestamp(base + timedelta(hours=offset_hours))
+
+
+def _interval(start_h: int, end_h: int, power: float) -> HomeLoadEnergyInterval:
+    """Build a 1-hour interval with a single power point and pre-computed energy."""
+    start = _ts(start_h)
+    end = _ts(end_h)
+    pp = HomeLoadPowerPoint(timestamp=start, power=Watts(power))
+    duration_hours = (end - start).total_seconds() / 3600.0
+    return HomeLoadEnergyInterval(
+        start=start,
+        end=end,
+        energy=WattHours(power * duration_hours),
+        power_points=[pp],
+    )
+
+
+def _consumption(*powers: float) -> LoadEnergyConsumption:
+    """Build a LoadEnergyConsumption with N hourly intervals starting at hour 0."""
+    intervals = [_interval(i, i + 1, p) for i, p in enumerate(powers)]
+    return LoadEnergyConsumption(timestamp=_ts(0), intervals=intervals)
+
+
+class TestMixForecast:
+    def test_empty_forecast_returns_unchanged(self):
+        empty = LoadEnergyConsumption(timestamp=_ts(0), intervals=[])
+        result = LoadEnergyConsumption.mix(empty, Watts(500.0))
+        assert result.intervals == []
+
+    def test_default_alpha_beta_equal_weights(self):
+        forecast = _consumption(200.0, 300.0, 400.0)
+        last_real = Watts(100.0)
+
+        result = LoadEnergyConsumption.mix(forecast, last_real)
+
+        # First interval: 0.5 * 200 + 0.5 * 100 = 150
+        assert result.intervals[0].avg_power == pytest.approx(200.0)  # power_points unchanged
+        assert float(result.intervals[0].energy) == pytest.approx(150.0)  # blended energy
+        # Remaining intervals unchanged
+        assert float(result.intervals[1].energy) == float(forecast.intervals[1].energy)
+        assert float(result.intervals[2].energy) == float(forecast.intervals[2].energy)
+
+    def test_alpha_1_beta_0_keeps_forecast(self):
+        forecast = _consumption(200.0, 300.0)
+        last_real = Watts(999.0)
+
+        result = LoadEnergyConsumption.mix(forecast, last_real, alpha=1.0, beta=0.0)
+
+        # 1.0 * 200 + 0.0 * 999 = 200
+        assert float(result.intervals[0].energy) == pytest.approx(200.0)
+
+    def test_alpha_0_beta_1_uses_real_only(self):
+        forecast = _consumption(200.0, 300.0)
+        last_real = Watts(500.0)
+
+        result = LoadEnergyConsumption.mix(forecast, last_real, alpha=0.0, beta=1.0)
+
+        # 0.0 * 200 + 1.0 * 500 = 500
+        assert float(result.intervals[0].energy) == pytest.approx(500.0)
+
+    def test_custom_weights(self):
+        forecast = _consumption(200.0, 300.0)
+        last_real = Watts(100.0)
+
+        result = LoadEnergyConsumption.mix(forecast, last_real, alpha=0.25, beta=0.75)
+
+        # 0.25 * 200 + 0.75 * 100 = 50 + 75 = 125
+        assert float(result.intervals[0].energy) == pytest.approx(125.0)
+
+    def test_single_interval_forecast(self):
+        forecast = _consumption(400.0)
+        last_real = Watts(200.0)
+
+        result = LoadEnergyConsumption.mix(forecast, last_real)
+
+        # 0.5 * 400 + 0.5 * 200 = 300
+        assert float(result.intervals[0].energy) == pytest.approx(300.0)
+        assert len(result.intervals) == 1
+
+    def test_preserves_timestamp(self):
+        forecast = _consumption(100.0, 200.0)
+        result = LoadEnergyConsumption.mix(forecast, Watts(50.0))
+        assert result.timestamp == forecast.timestamp
+
+    def test_preserves_power_points(self):
+        forecast = _consumption(100.0, 200.0)
+        result = LoadEnergyConsumption.mix(forecast, Watts(50.0))
+        assert result.intervals[0].power_points == forecast.intervals[0].power_points
+
+    def test_preserves_interval_times(self):
+        forecast = _consumption(100.0, 200.0, 300.0)
+        result = LoadEnergyConsumption.mix(forecast, Watts(50.0))
+        for i in range(3):
+            assert result.intervals[i].start == forecast.intervals[i].start
+            assert result.intervals[i].end == forecast.intervals[i].end
+
+    def test_does_not_mutate_original(self):
+        forecast = _consumption(200.0, 300.0)
+        original_energy = float(forecast.intervals[0].energy)
+
+        LoadEnergyConsumption.mix(forecast, Watts(100.0))
+
+        assert float(forecast.intervals[0].energy) == original_energy
diff --git a/tests/unit/domain/home_load/test_load_energy_consumption_windows.py b/tests/unit/domain/home_load/test_load_energy_consumption_windows.py
new file mode 100644
index 0000000..e0b46a5
--- /dev/null
+++ b/tests/unit/domain/home_load/test_load_energy_consumption_windows.py
@@ -0,0 +1,152 @@
+"""Unit tests for LoadEnergyConsumption extended window properties (F2 — 24h horizon)."""
+
+from datetime import datetime, timedelta, timezone
+
+import pytest
+
+from edge_mining.domain.common import Timestamp, WattHours, Watts
+from edge_mining.domain.home_load.value_objects import (
+    HomeLoadEnergyInterval,
+    HomeLoadPowerPoint,
+    LoadEnergyConsumption,
+)
+
+# Fixed "now" used as anchor for all tests
+_NOW = datetime(2026, 4, 25, 12, 0, 0, tzinfo=timezone.utc)
+
+
+def _ts(offset_hours: float = 0) -> Timestamp:
+    return Timestamp(_NOW + timedelta(hours=offset_hours))
+
+
+def _interval(start_h: float, end_h: float, power: float = 100.0) -> HomeLoadEnergyInterval:
+    start = _ts(start_h)
+    end = _ts(end_h)
+    pp = HomeLoadPowerPoint(timestamp=start, power=Watts(power))
+    duration_hours = (end - start).total_seconds() / 3600.0
+    return HomeLoadEnergyInterval(
+        start=start,
+        end=end,
+        energy=WattHours(power * duration_hours),
+        power_points=[pp],
+    )
+
+
+def _make_24h_forecast() -> LoadEnergyConsumption:
+    """Build a 24-hour forecast with one interval per hour starting at _NOW."""
+    intervals = [_interval(h, h + 1, power=float(100 + h * 10)) for h in range(24)]
+    return LoadEnergyConsumption(timestamp=_ts(0), intervals=intervals)
+
+
+def _make_24h_history() -> LoadEnergyConsumption:
+    """Build a 24-hour history ending at _NOW."""
+    intervals = [_interval(-24 + h, -23 + h, power=float(50 + h * 5)) for h in range(24)]
+    return LoadEnergyConsumption(timestamp=_ts(0), intervals=intervals)
+
+
+class TestExtendedNextWindowProperties:
+    """Verify next_Xh properties return correct number of intervals."""
+
+    def test_next_1h(self):
+        forecast = _make_24h_forecast()
+        subset = forecast.in_next_hours(1, now=_ts(0))
+        assert len(subset.intervals) == 1
+
+    def test_next_2h(self):
+        forecast = _make_24h_forecast()
+        subset = forecast.in_next_hours(2, now=_ts(0))
+        assert len(subset.intervals) == 2
+
+    def test_next_4h(self):
+        forecast = _make_24h_forecast()
+        subset = forecast.in_next_hours(4, now=_ts(0))
+        assert len(subset.intervals) == 4
+
+    def test_next_6h(self):
+        forecast = _make_24h_forecast()
+        subset = forecast.in_next_hours(6, now=_ts(0))
+        assert len(subset.intervals) == 6
+
+    def test_next_8h(self):
+        forecast = _make_24h_forecast()
+        subset = forecast.in_next_hours(8, now=_ts(0))
+        assert len(subset.intervals) == 8
+
+    def test_next_12h(self):
+        forecast = _make_24h_forecast()
+        subset = forecast.in_next_hours(12, now=_ts(0))
+        assert len(subset.intervals) == 12
+
+    def test_next_24h(self):
+        forecast = _make_24h_forecast()
+        subset = forecast.in_next_hours(24, now=_ts(0))
+        assert len(subset.intervals) == 24
+
+    def test_next_24h_total_energy(self):
+        forecast = _make_24h_forecast()
+        subset = forecast.in_next_hours(24, now=_ts(0))
+        assert float(subset.total_energy) == float(forecast.total_energy)
+
+
+class TestExtendedLastWindowProperties:
+    """Verify last_Xh properties return correct number of intervals."""
+
+    def test_last_1h(self):
+        history = _make_24h_history()
+        subset = history.in_last_hours(1, now=_ts(0))
+        assert len(subset.intervals) == 1
+
+    def test_last_4h(self):
+        history = _make_24h_history()
+        subset = history.in_last_hours(4, now=_ts(0))
+        assert len(subset.intervals) == 4
+
+    def test_last_12h(self):
+        history = _make_24h_history()
+        subset = history.in_last_hours(12, now=_ts(0))
+        assert len(subset.intervals) == 12
+
+    def test_last_24h(self):
+        history = _make_24h_history()
+        subset = history.in_last_hours(24, now=_ts(0))
+        assert len(subset.intervals) == 24
+
+    def test_last_24h_total_energy(self):
+        history = _make_24h_history()
+        subset = history.in_last_hours(24, now=_ts(0))
+        assert float(subset.total_energy) == float(history.total_energy)
+
+
+class TestWindowAggregates:
+    """Verify energy/power aggregates on windowed subsets."""
+
+    def test_next_6h_avg_power(self):
+        forecast = _make_24h_forecast()
+        subset = forecast.in_next_hours(6, now=_ts(0))
+        # Intervals 0..5 have power 100,110,120,130,140,150 → avg = 125
+        assert float(subset.avg_power) == pytest.approx(125.0)
+
+    def test_next_6h_peak_power(self):
+        forecast = _make_24h_forecast()
+        subset = forecast.in_next_hours(6, now=_ts(0))
+        assert float(subset.peak_power) == pytest.approx(150.0)
+
+    def test_next_12h_total_energy(self):
+        forecast = _make_24h_forecast()
+        subset = forecast.in_next_hours(12, now=_ts(0))
+        # Powers: 100,110,...,210 → sum = 12 * 100 + 10*(0+1+..+11) = 1200+660 = 1860
+        # Each 1h interval → energy = power, so total = 1860
+        assert float(subset.total_energy) == pytest.approx(1860.0)
+
+    def test_partial_window_returns_overlapping_intervals(self):
+        """Window that starts mid-interval still returns that interval."""
+        forecast = _make_24h_forecast()
+        subset = forecast.in_window(_ts(0.5), _ts(2.5))
+        # Intervals [0,1), [1,2), [2,3) all overlap [0.5, 2.5)
+        assert len(subset.intervals) == 3
+
+    def test_empty_window(self):
+        forecast = _make_24h_forecast()
+        subset = forecast.in_next_hours(1, now=_ts(100))
+        assert len(subset.intervals) == 0
+        assert float(subset.total_energy) == 0.0
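
The α/β blending arithmetic that `TestMixForecast` exercises can be sketched as a standalone helper. This is a minimal illustration of the weighted-sum rule the tests assert (only the first forecast interval is blended with the last real reading), not the project's actual `LoadEnergyConsumption.mix()` implementation; `blend_first_interval` and its signature are hypothetical.

```python
def blend_first_interval(
    forecast_energies: list[float],
    last_real_power: float,
    alpha: float = 0.5,
    beta: float = 0.5,
) -> list[float]:
    """Hypothetical sketch of the alpha/beta blend verified by TestMixForecast.

    Only the FIRST forecast interval's energy is blended with the last real
    power reading; later intervals pass through unchanged.
    """
    if not forecast_energies:
        # Empty forecast is returned unchanged
        # (cf. test_empty_forecast_returns_unchanged).
        return []
    blended = list(forecast_energies)  # copy: the input list is not mutated
    # For 1-hour intervals, energy in Wh numerically equals average power in W,
    # so the blend reduces to a weighted sum of two scalars.
    blended[0] = alpha * forecast_energies[0] + beta * last_real_power
    return blended


print(blend_first_interval([200.0, 300.0, 400.0], 100.0))  # [150.0, 300.0, 400.0]
print(blend_first_interval([200.0, 300.0], 100.0, alpha=0.25, beta=0.75))  # [125.0, 300.0]
```

With the defaults `alpha=0.5, beta=0.5`, the first interval becomes the midpoint of forecast and last real reading, which matches the expected values asserted in `test_default_alpha_beta_equal_weights` and `test_custom_weights`.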