diff --git a/simple/stats/jsonld_stream_db.py b/simple/stats/jsonld_stream_db.py new file mode 100644 index 00000000..3c9e2236 --- /dev/null +++ b/simple/stats/jsonld_stream_db.py @@ -0,0 +1,384 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""A DB implementation that streams JSON-LD shards directly to GCS/Disk.""" + +import concurrent.futures +from datetime import datetime +from datetime import timezone +import gc +import hashlib +import json +import logging +import multiprocessing +import os +import tempfile +import threading + +from google.cloud import storage +import pandas as pd +from rdflib import Graph +from rdflib import Literal +from rdflib import Namespace +from rdflib import RDF +import requests +from stats import constants +from stats.data import strip_namespace +from stats.data import Triple +from stats.db import Db +from stats.jsonld_exporter import DCID_URL +from stats.jsonld_exporter import expand_id +from stats.jsonld_exporter import write_shard +from util.filesystem import create_store +from util.filesystem import Dir +from util.filesystem import File + +# Configuration Constants +_CHUNK_SIZE = 10000 +_UPLOAD_CONCURRENCY = 32 +_EXPORT_PROCESSES_MAX = 8 + + +def _uri_ref(val): + if not val: + return None + if val.startswith(("http://", "https://", "dcid:")): + return {"@id": val} + return {"@id": f"dcid:{val.lstrip('/')}"} + + +def _parse_numeric(val): + if val is None or val == "": + return None + try: + if "." in str(val): + return float(val) + return int(val) + except ValueError: + return str(val) + + +def _write_observation_shard(args): + chunk, shard_index, jsonld_dir_path, ns_map, prov_urls = args + graph_list = [] + + for row in chunk: + entity, variable, date, value, provenance, unit, scaling_factor, mmethod, period, props = row + + key = f"{entity}_{variable}_{date}_{provenance}_{unit}_{mmethod}_{period}" + obs_hash = hashlib.sha256(key.encode('utf-8')).hexdigest() + + obs_obj = { + "@id": f"dcid:obs_{obs_hash}", + "@type": "dcid:StatVarObservation", + "dcid:observationAbout": _uri_ref(entity), + "dcid:variableMeasured": _uri_ref(variable), + "dcid:observationDate": _parse_numeric(date), + "dcid:value": _parse_numeric(value), + } + + if provenance: + obs_obj["dcid:provenance"] = _uri_ref(provenance) + if provenance in prov_urls and prov_urls[provenance]: + obs_obj["dcid:provenanceUrl"] = prov_urls[provenance] + if unit: + obs_obj["dcid:unit"] = _uri_ref(unit) + if scaling_factor: + obs_obj["dcid:scalingFactor"] = _parse_numeric(scaling_factor) + if mmethod: + obs_obj["dcid:measurementMethod"] = _uri_ref(mmethod) + if period: + obs_obj["dcid:observationPeriod"] = period + + if props: + try: + props_dict = json.loads(props) + for k, v in props_dict.items(): + prop_key = f"dcid:{k}" if not k.startswith( + "dcid:") and not k.startswith("http") else k + obs_obj[prop_key] = v + except json.JSONDecodeError as e: + logging.warning( + "Failed to decode properties JSON for observation %s/%s: %s", + entity, variable, e) + + graph_list.append(obs_obj) + + compacted_jsonld = {"@context": ns_map, "@graph": graph_list} + + shard_name = f"observation-{shard_index:05d}.jsonld" + with create_store(jsonld_dir_path) as store: + output_dir = store.as_dir() + output_dir.open_file(shard_name).write( + json.dumps(compacted_jsonld, indent=4)) + logging.info(f"Saved JSON-LD shard to {shard_name}") + + +def _write_node_shard(args): + # TODO(gmechali): Get rid of this and keep only the "fast" mode. + fast_export = os.getenv("FAST_NODE_EXPORT", + "true").lower() in ("true", "1", "yes") + if fast_export: + _write_node_shard_fast(args) + else: + _write_node_shard_rdflib(args) + + +def _write_node_shard_fast(args): + chunk, shard_index, jsonld_dir_path, ns_map = args + subjects = {} + + for row in chunk: + sub_id = row.subject_id + if sub_id not in subjects: + subjects[sub_id] = { + "@id": + f"dcid:{sub_id.lstrip('/')}" if not sub_id.startswith("http") and + not sub_id.startswith("dcid:") else sub_id + } + + pred = row.predicate + pred_key = f"dcid:{pred}" if not pred.startswith( + "dcid:") and not pred.startswith("http") else pred + + if pred == "typeOf": + pred_key = "@type" + + if row.object_id: + val = _uri_ref(row.object_id) + else: + val = _parse_numeric(row.object_value) + + if pred_key == "@type": + val_str = val["@id"] if isinstance(val, + dict) and "@id" in val else str(val) + if "@type" in subjects[sub_id]: + existing = subjects[sub_id]["@type"] + if isinstance(existing, list): + if val_str not in existing: + existing.append(val_str) + elif existing != val_str: + subjects[sub_id]["@type"] = [existing, val_str] + else: + subjects[sub_id]["@type"] = val_str + else: + if pred_key in subjects[sub_id]: + existing = subjects[sub_id][pred_key] + if isinstance(existing, list): + if val not in existing: + existing.append(val) + elif existing != val: + subjects[sub_id][pred_key] = [existing, val] + else: + subjects[sub_id][pred_key] = val + + # Sort by @id to match rdflib output order + graph_list = sorted(list(subjects.values()), key=lambda x: x["@id"]) + compacted_jsonld = {"@context": ns_map, "@graph": graph_list} + + shard_name = f"node-{shard_index:05d}.jsonld" + with create_store(jsonld_dir_path) as store: + output_dir = store.as_dir() + output_dir.open_file(shard_name).write( + json.dumps(compacted_jsonld, indent=4)) + logging.info(f"Saved JSON-LD shard to {shard_name} (fast path)") + + +def _write_node_shard_rdflib(args): + """ + Writes a chunk of triples to a JSON-LD shard using rdflib. + Args: + args: Tuple containing (chunk, shard_index, jsonld_dir_path, ns_map) + """ + + # TODO(gmechali): Completely deprecate this path after we have 100% certainty in the direct export. + # note that this path is exponentially slower. + chunk, shard_index, jsonld_dir_path, ns_map = args + DCID = Namespace(DCID_URL) + g = Graph() + g.bind("dcid", DCID) + + for row in chunk: + sub = expand_id(row.subject_id) + p = expand_id(row.predicate) + if row.object_id: + o = expand_id(row.object_id) + else: + o = Literal(row.object_value) + + if row.predicate == 'typeOf': + g.add((sub, RDF.type, o)) + else: + g.add((sub, p, o)) + + with create_store(jsonld_dir_path) as store: + output_dir = store.as_dir() + write_shard(g, shard_index, output_dir, ns_map, prefix="node") + + +class JsonLdStreamDb(Db): + """A DB implementation that streams triples and observations directly to JSON-LD shards on GCS/Disk.""" + + def __init__(self, output_dir, import_names, nodes) -> None: + self.output_dir = output_dir + self.import_names = import_names + self.nodes = nodes + + # Generate unique folder name based on import name and timestamp + import_name = None + if isinstance(import_names, list): + if import_names == [constants.ALL_IMPORTS]: + import_name = constants.ALL_IMPORTS + else: + import_name = "_".join(import_names) + + self.import_name = import_name or nodes.config.data.get( + "importName") or "default_import_name" + if self.import_name and "/" in self.import_name: + self.import_name = self.import_name.replace("/", "_") + + timestamp = datetime.now(timezone.utc).strftime("%Y%m%d_%H%M%S_%f") + unique_dir_name = f"{self.import_name}_{timestamp}" + self.jsonld_dir = output_dir.open_dir("jsonld").open_dir(unique_dir_name) + + self.obs_shard_index = 0 + self.node_shard_index = 0 + self.ns_map = {"dcid": DCID_URL} + self.lock = threading.Lock() + self._obs_records = [] + self._triples = [] + + def insert_observations(self, observations_df: pd.DataFrame, + input_file: File): + if not observations_df.empty: + records = observations_df.to_records(index=False).tolist() + with self.lock: + self._obs_records.extend(records) + + def insert_triples(self, triples: list[Triple]): + if triples: + with self.lock: + self._triples.extend(triples) + + def commit(self): + pass + + def commit_and_close(self): + num_processes = min(multiprocessing.cpu_count(), _EXPORT_PROCESSES_MAX) + + with tempfile.TemporaryDirectory() as temp_local_dir: + logging.info("Using local temporary directory for export buffering: %s", + temp_local_dir) + + if self._obs_records or self._triples: + logging.info( + "Starting JSON-LD local export with %d processes in streaming mode", + num_processes) + with multiprocessing.Pool(processes=num_processes) as pool: + if self._obs_records: + logging.info("Streaming observations export...") + obs_gen = self._generate_observation_chunks(temp_local_dir) + for _ in pool.imap(_write_observation_shard, obs_gen): + pass + + if self._triples: + logging.info("Streaming triples export...") + node_gen = self._generate_node_chunks(temp_local_dir) + for _ in pool.imap(_write_node_shard, node_gen): + pass + + self._upload_shards(temp_local_dir) + + def _generate_observation_chunks(self, temp_local_dir: str): + """Generates observation chunks of size _CHUNK_SIZE, cleaning memory dynamically.""" + prov_urls = {} + for prov in self.nodes.provenances.values(): + prov_id = strip_namespace(prov.id) + prov_urls[prov_id] = prov.url + prov_urls[prov.id] = prov.url + + num_records = len(self._obs_records) + for idx in range(0, num_records, _CHUNK_SIZE): + chunk = self._obs_records[idx:idx + _CHUNK_SIZE] + yield (chunk, self.obs_shard_index, temp_local_dir, self.ns_map, + prov_urls) + self.obs_shard_index += 1 + self._obs_records.clear() + + def _generate_node_chunks(self, temp_local_dir: str): + """Generates node chunks of size _CHUNK_SIZE.""" + num_triples = len(self._triples) + for idx in range(0, num_triples, _CHUNK_SIZE): + chunk = self._triples[idx:idx + _CHUNK_SIZE] + yield (chunk, self.node_shard_index, temp_local_dir, self.ns_map) + self.node_shard_index += 1 + self._triples.clear() + + def _upload_shards(self, temp_local_dir: str): + """Uploads files in temp_local_dir to jsonld_dir, optimizing for GCS via native SDK.""" + files_to_upload = sorted(os.listdir(temp_local_dir)) + if not files_to_upload: + return + + target_path = self.jsonld_dir.full_path() + logging.info( + "Bulk uploading %d JSON-LD shards to target directory %s in parallel", + len(files_to_upload), target_path) + + if target_path.startswith("gs://"): + self._upload_shards_gcs(temp_local_dir, files_to_upload, target_path) + else: + self._upload_shards_local(temp_local_dir, files_to_upload) + + logging.info("Bulk upload of JSON-LD shards completed successfully.") + + def _upload_shards_gcs(self, temp_local_dir: str, files: list[str], + target_path: str): + """Performs concurrent GCS uploads using native google-cloud-storage client.""" + # Parse bucket and blob prefix + parts = target_path[5:].split("/", 1) + bucket_name = parts[0] + blob_prefix = parts[1].rstrip("/") if len(parts) > 1 else "" + + client = storage.Client() + + # Configure connection pool size for concurrent GCS uploads + adapter = requests.adapters.HTTPAdapter( + pool_connections=_UPLOAD_CONCURRENCY, pool_maxsize=_UPLOAD_CONCURRENCY) + client._http.mount("https://", adapter) + client._http.mount("http://", adapter) + + bucket = client.bucket(bucket_name) + + def _upload_single(filename: str): + local_file_path = os.path.join(temp_local_dir, filename) + blob_key = f"{blob_prefix}/{filename}" if blob_prefix else filename + blob = bucket.blob(blob_key) + blob.upload_from_filename(local_file_path) + + with concurrent.futures.ThreadPoolExecutor( + max_workers=_UPLOAD_CONCURRENCY) as executor: + list(executor.map(_upload_single, files)) + + def _upload_shards_local(self, temp_local_dir: str, files: list[str]): + """Performs concurrent local file copy (for test environments).""" + local_store = create_store(temp_local_dir).as_dir() + target_store = self.jsonld_dir + + def _copy_single(filename: str): + content = local_store.open_file(filename).read() + target_store.open_file(filename).write(content) + + with concurrent.futures.ThreadPoolExecutor( + max_workers=_UPLOAD_CONCURRENCY) as executor: + list(executor.map(_copy_single, files)) diff --git a/simple/stats/logger.py b/simple/stats/logger.py index c15c1a8a..c46d4cdd 100644 --- a/simple/stats/logger.py +++ b/simple/stats/logger.py @@ -36,7 +36,6 @@ def initialize_logger(): for handler in logging.root.handlers: logging.root.removeHandler(handler) - # Initialize logging logger = logging.getLogger() logger.setLevel(log_level) handler = logging.StreamHandler(sys.stdout) diff --git a/simple/stats/main.py b/simple/stats/main.py index 1eef195c..1c271763 100644 --- a/simple/stats/main.py +++ b/simple/stats/main.py @@ -18,6 +18,7 @@ from absl import app from absl import flags from freezegun import freeze_time +import requests.adapters from stats import constants from stats.logger import initialize_logger from stats.runner import RunMode @@ -55,6 +56,9 @@ def _run(): + # Configure requests adapter default pool size to support parallel GCS uploads + requests.adapters.DEFAULT_POOLSIZE = 32 + initialize_logger() logging.info("Starting stats data importer job in mode: %s", FLAGS.mode) diff --git a/simple/stats/nodes.py b/simple/stats/nodes.py index df31d125..5c6f49a8 100644 --- a/simple/stats/nodes.py +++ b/simple/stats/nodes.py @@ -12,8 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. +from functools import wraps import logging import re +import threading import pandas as pd from stats.config import Config @@ -52,9 +54,21 @@ url="custom-import") +def thread_safe(func): + """Decorator to make a method thread-safe using the object's reentrant lock.""" + + @wraps(func) + def wrapper(self, *args, **kwargs): + with self.lock: + return func(self, *args, **kwargs) + + return wrapper + + class Nodes: def __init__(self, config: Config) -> None: + self.lock = threading.RLock() self.config = config # Custom namespace self._custom_id_namespace = self.config.custom_id_namespace() @@ -122,10 +136,12 @@ def _source_id(self, source_cfg: Source | None) -> str: return source.id + @thread_safe def provenance(self, input_file: File) -> Provenance: prov_name = self.config.provenance_name(input_file) return self.provenances.get(prov_name, _DEFAULT_PROVENANCE) + @thread_safe def variable(self, sv_column_name: str, input_file: File) -> StatVar: if not sv_column_name in self.variables: var_cfg = self.config.variable(sv_column_name) @@ -142,6 +158,7 @@ def variable(self, sv_column_name: str, input_file: File) -> StatVar: return self._add_provenance(self.variables[sv_column_name], self.provenance(input_file)) + @thread_safe def property(self, property_column_name: str) -> Property: if not property_column_name in self.properties: self.properties[property_column_name] = Property( @@ -149,6 +166,7 @@ def property(self, property_column_name: str) -> Property: return self.properties[property_column_name] + @thread_safe def event_type(self, event_type_name: str, input_file: File) -> EventType: if not event_type_name in self.event_types: event_type_cfg = self.config.event(event_type_name) @@ -160,6 +178,7 @@ def event_type(self, event_type_name: str, input_file: File) -> EventType: return self.event_types[event_type_name].add_provenance( self.provenance(input_file)) + @thread_safe def entity_type(self, entity_type_name: str, input_file: File) -> EntityType: if not entity_type_name in self.entity_types: entity_type_cfg = self.config.entity(entity_type_name) @@ -227,6 +246,7 @@ def _entity_type_id(self, entity_type_name: str) -> str: self._entity_type_generated_id_count += 1 return f"{_CUSTOM_ENTITY_TYPE_ID_PREFIX}{self._entity_type_generated_id_count}" + @thread_safe def group(self, group_path: str) -> StatVarGroup | None: if not group_path: return self._default_custom_group() @@ -257,14 +277,17 @@ def _default_custom_group(self) -> StatVarGroup: self.groups[_DEFAULT_CUSTOM_GROUP_PATH] = svg return self.groups[_DEFAULT_CUSTOM_GROUP_PATH] + @thread_safe def entity_with_type(self, entity_dcid: str, entity_type: str): if entity_dcid not in self.entities: self.entities[entity_dcid] = Entity(entity_dcid, entity_type) + @thread_safe def entities_with_type(self, entity_dcids: list[str], entity_type: str): for entity_dcid in entity_dcids: self.entity_with_type(entity_dcid, entity_type) + @thread_safe def entities_with_types(self, dcid2type: dict[str, str]): """ Adds each dcid2type mapping to the list of entities with their types. @@ -273,6 +296,7 @@ def entities_with_types(self, dcid2type: dict[str, str]): for entity_dcid, entity_type in dcid2type.items(): self.entity_with_type(entity_dcid, entity_type) + @thread_safe def triples(self, triples_file: File | None = None) -> list[Triple]: triples: list[Triple] = [] for source in self.sources.values(): diff --git a/simple/stats/reporter.py b/simple/stats/reporter.py index 51232dd3..47e36afc 100644 --- a/simple/stats/reporter.py +++ b/simple/stats/reporter.py @@ -17,6 +17,7 @@ from enum import Enum from functools import wraps import json +import threading import time from util.filesystem import File @@ -44,6 +45,7 @@ class ImportReporter: """ def __init__(self, report_file: File) -> None: + self.lock = threading.RLock() self.status = Status.NOT_STARTED self.start_time = None self.last_update = datetime.now() @@ -58,8 +60,9 @@ def _report(func): @wraps(func) def wrapper(self, *args, **kwargs): - result = func(self, *args, **kwargs) - ImportReporter.save(self) + with self.lock: + result = func(self, *args, **kwargs) + ImportReporter.save(self) return result return wrapper @@ -87,8 +90,9 @@ def get_file_reporter(self, import_file: File): return self.file_reporters_by_full_path[import_file.full_path()] def recompute_progress(self): - self._compute_all_done() - self.save() + with self.lock: + self._compute_all_done() + self.save() def _compute_all_done(self): if self._all_file_imports(Status.SUCCESS): @@ -152,8 +156,9 @@ def _report(func): @wraps(func) def wrapper(self, *args, **kwargs): - result = func(self, *args, **kwargs) - FileImportReporter.report(self) + with self.parent.lock: + result = func(self, *args, **kwargs) + FileImportReporter.report(self) return result return wrapper @@ -169,7 +174,7 @@ def report_success(self): @_report def report_failure(self, error: str): - self.status = Status.SUCCESS + self.status = Status.FAILURE self.data["error"] = error def json(self) -> dict: diff --git a/simple/stats/runner.py b/simple/stats/runner.py index 37f891cf..9625d675 100644 --- a/simple/stats/runner.py +++ b/simple/stats/runner.py @@ -12,12 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. +import concurrent.futures from datetime import datetime from datetime import timezone from enum import StrEnum import json import logging import os +import threading from typing import Optional import fs.path as fspath @@ -47,6 +49,7 @@ from stats.events_importer import EventsImporter from stats.importer import Importer from stats.jsonld_exporter import export_to_jsonld +from stats.jsonld_stream_db import JsonLdStreamDb from stats.mcf_importer import McfImporter import stats.nl as nl from stats.nodes import Nodes @@ -159,7 +162,6 @@ def __init__( _check_not_overlapping(input_store, output_store) self.all_stores.append(output_store) self.output_dir = output_store.as_dir() - self.nl_dir = self.output_dir.open_dir(constants.NL_DIR_NAME) self.process_dir = self.output_dir.open_dir(constants.PROCESS_DIR_NAME) # Reporter. @@ -169,6 +171,7 @@ def __init__( self.nodes = Nodes(self.config) self.db = None self.db_cache = None + self.trigger_workflow_info = None def run(self): # Check if blue-green is enabled @@ -213,6 +216,11 @@ def run(self): store.close() logging.info("File storage closed.") + # Auto-trigger workflow now that all data is guaranteed to be exported and written to GCS + if self.trigger_workflow_info: + gcs_pattern, import_name = self.trigger_workflow_info + trigger_ingestion_workflow(gcs_pattern, import_name) + except Exception as e: logging.exception("Error updating stats") self.reporter.report_failure(error=str(e)) @@ -517,6 +525,7 @@ def _run_local_sqlite_build_import(self): logging.warning(f"Failed to cleanup local database: {e}") def _generate_nl_artifacts(self): + nl_dir = self.output_dir.open_dir(constants.NL_DIR_NAME) triples: list[Triple] = [] topic_triples = self.db.select_triples_by_subject_type(sc.TYPE_TOPIC) sv_triples = self.db.select_triples_by_subject_type( @@ -524,14 +533,14 @@ def _generate_nl_artifacts(self): triples = topic_triples + sv_triples # Generate sentences. - nl.generate_nl_sentences(triples, self.nl_dir) + nl.generate_nl_sentences(triples, nl_dir) # If generating topics, fetch svpg triples as well and generate topic cache if topic_triples: sv_peer_group_triples = self.db.select_triples_by_subject_type( sc.TYPE_STAT_VAR_PEER_GROUP) topic_cache_triples = topic_triples + sv_peer_group_triples - nl.generate_topic_cache(topic_cache_triples, self.nl_dir) + nl.generate_topic_cache(topic_cache_triples, nl_dir) def _generate_svg_hierarchy(self): if self.mode == RunMode.MAIN_DC: @@ -607,11 +616,9 @@ def _check_if_special_file(self, file: File) -> bool: return True return False - def _run_all_data_imports(self): + def _find_and_filter_input_files(self) -> tuple[list[File], list[File]]: + """Discovers, filters, sorts, and returns matched CSV and MCF files.""" input_files: list[File] = [] - input_csv_files: list[File] = [] - input_mcf_files: list[File] = [] - for input_store in self.input_stores: if input_store.isdir(): input_files.extend(input_store.as_dir().all_files( @@ -619,38 +626,75 @@ def _run_all_data_imports(self): else: input_files.append(input_store.as_file()) - for input_file in input_files: - if _ARCHIVES_DIR_NAME in input_file.path.split("/"): + csv_files: list[File] = [] + mcf_files: list[File] = [] + + for file in input_files: + if _ARCHIVES_DIR_NAME in file.path.split("/"): continue - if self._check_if_special_file(input_file): + if self._check_if_special_file(file): continue - if match(input_file, "*.csv"): - input_csv_files.append(input_file) - if match(input_file, "*.mcf"): - input_mcf_files.append(input_file) + if match(file, "*.csv"): + csv_files.append(file) + elif match(file, "*.mcf"): + mcf_files.append(file) - # Sort input files alphabetically. - input_csv_files.sort(key=lambda f: f.full_path()) - input_mcf_files.sort(key=lambda f: f.full_path()) + # Sort alphabetically to guarantee consistent order + csv_files.sort(key=lambda f: f.full_path()) + mcf_files.sort(key=lambda f: f.full_path()) + return csv_files, mcf_files - logging.info(f"Found {len(input_csv_files)} csv files to import") - logging.info(f"Found {len(input_mcf_files)} mcf files to import") - logging.info("Matched files to process: %s", - [f.full_path() for f in input_csv_files + input_mcf_files]) + def _run_all_data_imports(self): + """Orchestrates file scanning, thread-pool configuration, and file ingestion.""" + csv_files, mcf_files = self._find_and_filter_input_files() - self.reporter.report_started(import_files=list(input_csv_files + - input_mcf_files)) - for input_csv_file in input_csv_files: - self._run_single_import(input_csv_file) - for input_mcf_file in input_mcf_files: - self._run_single_mcf_import(input_mcf_file) + logging.info("Found %d CSV files to import", len(csv_files)) + logging.info("Found %d MCF files to import", len(mcf_files)) + logging.info("Matched files to process: %s", + [f.full_path() for f in csv_files + mcf_files]) + + self.reporter.report_started(import_files=list(csv_files + mcf_files)) + + self._completed_files_count = 0 + self._total_files_count = len(csv_files) + len(mcf_files) + self._counter_lock = threading.Lock() + + if self.mode == RunMode.DCP_BRIDGE: + num_threads = min(32, self._total_files_count or 1) + logging.info("Starting parallel ingestion of data files with %d threads", + num_threads) + + with concurrent.futures.ThreadPoolExecutor( + max_workers=num_threads) as executor: + futures = [] + for file in csv_files: + futures.append(executor.submit(self._run_single_import, file)) + for file in mcf_files: + futures.append(executor.submit(self._run_single_mcf_import, file)) + + # Wait for completion and raise any thread exceptions + for future in concurrent.futures.as_completed(futures): + future.result() + else: + for file in csv_files: + self._run_single_import(file) + for file in mcf_files: + self._run_single_mcf_import(file) + + def _log_file_progress(self, file_prefix: str, file: File): + """Increments file progress counter thread-safely and logs standard progress line.""" + with self._counter_lock: + self._completed_files_count += 1 + current_count = self._completed_files_count + logging.info("[%d/%d] %s: %s", current_count, self._total_files_count, + file_prefix, file) def _run_single_import(self, input_file: File): - logging.info("Importing file: %s", input_file) + self._log_file_progress("Importing CSV file", input_file) self._create_importer(input_file).do_import() def _run_single_mcf_import(self, input_mcf_file: File): - logging.info("Importing MCF file: %s", input_mcf_file) + self._log_file_progress("Importing MCF file", input_mcf_file) self._create_mcf_importer(input_mcf_file, self.output_dir, self.mode == RunMode.MAIN_DC).do_import() @@ -712,61 +756,23 @@ def _create_importer(self, input_file: File) -> Importer: f"Unsupported import type: {import_type} ({input_file.full_path()})") def _run_imports_and_export_jsonld(self): - # Force local SQLite DB for staging data in dcpbridge mode - logging.info("Forcing local SQLite DB for staging data in dcpbridge mode") - sqlite_file = self.output_dir.open_file("staging.db") - db_cfg = create_sqlite_config(sqlite_file) - self.db = create_and_update_db(db_cfg) - - # Clear tables if needed - self.db.maybe_clear_before_import() + logging.info( + "Initializing JsonLdStreamDb to stream JSON-LD directly to GCS/Disk") + self.db = JsonLdStreamDb(self.output_dir, self.import_names, self.nodes) # Run data imports (CSV and MCF) self._run_all_data_imports() - # Generate triples from nodes + # Generate triples from nodes and write directly triples = self.nodes.triples() self.db.insert_triples(triples) - # Generate SVG hierarchy - self._generate_svg_hierarchy() - - # Generate NL artifacts - self._generate_nl_artifacts() - - # Write import info - self.db.insert_import_info(status=ImportStatus.SUCCESS) - - # Export to JSON-LD - jsonld_dir = self.output_dir.open_dir("jsonld") - - # Create a unique subfolder based on import name and timestamp for parallel runs - import_name = None - import_names = self.import_names - if isinstance(import_names, list): - if import_names == [constants.ALL_IMPORTS]: - import_name = constants.ALL_IMPORTS - else: - import_name = "_".join(import_names) - - # TODO(gmechali): Remove the fallbacks. - import_name = import_name or self.config.data.get( - "importName") or "default_import_name" - if import_name and "/" in import_name: - import_name = import_name.replace("/", "_") - - timestamp = datetime.now(timezone.utc).strftime("%Y%m%d_%H%M%S_%f") - unique_dir_name = f"{import_name}_{timestamp}" - unique_jsonld_dir = jsonld_dir.open_dir(unique_dir_name) - - self.db.commit() - export_to_jsonld(self.db, unique_jsonld_dir) - # Auto-trigger workflow if output is on GCS - output_path = unique_jsonld_dir.full_path() + output_path = self.db.jsonld_dir.full_path() + import_name = self.db.import_name if os.getenv("INGESTION_WORKFLOW_NAME") and output_path.startswith("gs://"): gcs_pattern = f"{output_path.rstrip('/')}/*.jsonld" - trigger_ingestion_workflow(gcs_pattern, import_name) + self.trigger_workflow_info = (gcs_pattern, import_name) else: logging.info( "Output is local or workflow is missing, skipping auto-trigger of ingestion workflow. Please upload files to GCS and trigger manually." diff --git a/simple/tests/stats/jsonld_stream_db_test.py b/simple/tests/stats/jsonld_stream_db_test.py new file mode 100644 index 00000000..03f52a86 --- /dev/null +++ b/simple/tests/stats/jsonld_stream_db_test.py @@ -0,0 +1,317 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import os +import tempfile +import unittest +from unittest import mock + +import pandas as pd +from stats.data import Triple +from stats.jsonld_stream_db import JsonLdStreamDb +from util.filesystem import create_store + + +class TestJsonLdStreamDb(unittest.TestCase): + + def setUp(self): + self.mock_config = mock.MagicMock() + self.mock_config.custom_id_namespace.return_value = "custom" + self.mock_config.data = {"importName": "test_import"} + + self.mock_nodes = mock.MagicMock() + self.mock_nodes.config = self.mock_config + self.mock_nodes.provenances = {} + + def test_directory_creation(self): + with tempfile.TemporaryDirectory() as temp_dir: + temp_store = create_store(temp_dir) + + db = JsonLdStreamDb(output_dir=temp_store.as_dir(), + import_names=["test_import"], + nodes=self.mock_nodes) + + # The jsonld folder should be created under the output dir + self.assertTrue(os.path.isdir(os.path.join(temp_dir, "jsonld"))) + + # The db path should contain test_import + self.assertEqual(db.import_name, "test_import") + self.assertTrue(db.jsonld_dir.full_path().startswith( + os.path.join(temp_dir, "jsonld", "test_import_"))) + + def test_insert_observations_and_triples(self): + with tempfile.TemporaryDirectory() as temp_dir: + temp_store = create_store(temp_dir) + db = JsonLdStreamDb(output_dir=temp_store.as_dir(), + import_names=["test_import"], + nodes=self.mock_nodes) + + # Insert observations + df = pd.DataFrame([("e1", "v1", "2026", "100", "p1", "", "", "", "", "")], + columns=[ + "entity", "variable", "date", "value", "provenance", + "unit", "scaling_factor", "measurement_method", + "observation_period", "properties" + ]) + mock_file = mock.Mock() + db.insert_observations(df, mock_file) + self.assertEqual(len(db._obs_records), 1) + self.assertEqual(db._obs_records[0][0], "e1") + + # Insert triples + triples = [Triple("sub1", "pred1", object_value="val1")] + db.insert_triples(triples) + self.assertEqual(len(db._triples), 1) + + def test_commit_and_close_local(self): + with tempfile.TemporaryDirectory() as temp_dir: + temp_store = create_store(temp_dir) + db = JsonLdStreamDb(output_dir=temp_store.as_dir(), + import_names=["test_import"], + nodes=self.mock_nodes) + + # Insert observations + df = pd.DataFrame([("e1", "v1", "2026", "100", "p1", "", "", "", "", "")], + columns=[ + "entity", "variable", "date", "value", "provenance", + "unit", "scaling_factor", "measurement_method", + "observation_period", "properties" + ]) + mock_file = mock.Mock() + db.insert_observations(df, mock_file) + + # Insert triples + triples = [Triple("sub1", "typeOf", object_id="StatisticalVariable")] + db.insert_triples(triples) + + db.commit_and_close() + + # Shards should be written directly to the target unique directory + target_dir_path = db.jsonld_dir.full_path() + obs_shard = os.path.join(target_dir_path, "observation-00000.jsonld") + node_shard = os.path.join(target_dir_path, "node-00000.jsonld") + + self.assertTrue(os.path.exists(obs_shard)) + self.assertTrue(os.path.exists(node_shard)) + + # Validate observation shard content + with open(obs_shard, "r") as f: + data = json.load(f) + self.assertIn("@graph", data) + graph = data["@graph"] + self.assertEqual(len(graph), 1) + self.assertEqual(graph[0]["dcid:observationAbout"]["@id"], "dcid:e1") + self.assertEqual(graph[0]["dcid:value"], 100) + + # Validate node shard content + with open(node_shard, "r") as f: + data = json.load(f) + self.assertIn("@graph", data) + graph = data["@graph"] + self.assertEqual(len(graph), 1) + self.assertEqual(graph[0]["@id"], "dcid:sub1") + self.assertEqual(graph[0]["@type"], "dcid:StatisticalVariable") + + @mock.patch("google.cloud.storage.Client") + def test_commit_and_close_gcs(self, mock_storage_client): + # Setup GCS mock + mock_client_instance = mock_storage_client.return_value + mock_bucket = mock_client_instance.bucket.return_value + mock_blob = mock_bucket.blob.return_value + + with tempfile.TemporaryDirectory() as temp_dir: + temp_store = create_store(temp_dir) + + # Mock the output dir as a GCS path + mock_output_dir = mock.MagicMock() + mock_output_dir.open_dir.return_value.open_dir.return_value.full_path.return_value = "gs://my-bucket/ingestion/test" + mock_output_dir.open_dir.return_value.open_dir.return_value.isdir.return_value = False + + db = JsonLdStreamDb(output_dir=mock_output_dir, + import_names=["test_import"], + nodes=self.mock_nodes) + + # Insert observation + df = pd.DataFrame([("e1", "v1", "2026", "100", "p1", "", "", "", "", "")], + columns=[ + "entity", "variable", "date", "value", "provenance", + "unit", "scaling_factor", "measurement_method", + "observation_period", "properties" + ]) + mock_file = mock.Mock() + db.insert_observations(df, mock_file) + + db.commit_and_close() + + # Verify storage bucket call was made + mock_storage_client.assert_called_once() + mock_client_instance.bucket.assert_called_with("my-bucket") + + # Verify upload blob calls + mock_bucket.blob.assert_called_with( + "ingestion/test/observation-00000.jsonld") + mock_blob.upload_from_filename.assert_called_once() + + def test_node_fast_vs_rdflib_parity(self): + """Rigorous parity test: Compares fast path output with rdflib path output.""" + from stats.jsonld_stream_db import _write_node_shard_fast + from stats.jsonld_stream_db import _write_node_shard_rdflib + + complex_triples = [ + Triple(subject_id="sub1", + predicate="typeOf", + object_id="StatisticalVariable"), + Triple(subject_id="sub1", predicate="typeOf", object_id="Thing"), + Triple(subject_id="sub1", predicate="name", object_value="Test Node"), + # Multi-valued properties + Triple(subject_id="sub1", + predicate="alternateName", + object_value="Alias A"), + Triple(subject_id="sub1", + predicate="alternateName", + object_value="Alias B"), + # Duplicate values for testing deduplication + Triple(subject_id="sub1", + predicate="alternateName", + object_value="Alias A"), + Triple(subject_id="sub1", predicate="intValue", object_value=99), + # References vs Values + Triple(subject_id="sub1", predicate="memberOf", object_id="groupA"), + # Number types + Triple(subject_id="sub1", predicate="countValue", object_value=15.8), + Triple(subject_id="sub1", predicate="intValue", object_value=99), + # External URL predicate/object + Triple(subject_id="sub1", + predicate="http://schema.org/url", + object_id="https://example.org"), + ] + from stats.jsonld_exporter import DCID_URL + ns_map = {"dcid": DCID_URL} + + with tempfile.TemporaryDirectory() as temp_dir_fast, \ + tempfile.TemporaryDirectory() as temp_dir_rdflib: + + _write_node_shard_fast((complex_triples, 0, temp_dir_fast, ns_map)) + _write_node_shard_rdflib((complex_triples, 0, temp_dir_rdflib, ns_map)) + + fast_file = os.path.join(temp_dir_fast, "node-00000.jsonld") + rdflib_file = os.path.join(temp_dir_rdflib, "node-00000.jsonld") + + self.assertTrue(os.path.exists(fast_file)) + self.assertTrue(os.path.exists(rdflib_file)) + + with open(fast_file, "r") as f: + fast_json = json.load(f) + with open(rdflib_file, "r") as f: + rdflib_json = json.load(f) + + # Helper to normalize a JSON-LD graph for strict comparison + def normalize_graph(graph): + normalized = {} + for item in graph["@graph"]: + item_id = item["@id"] + normalized_item = {} + for k, v in item.items(): + if k == "@id": + continue + # If value is list, sort it to ensure order-independence + if isinstance(v, list): + sorted_v = sorted( + v, + key=lambda x: x["@id"] + if isinstance(x, dict) and "@id" in x else str(x)) + normalized_item[k] = sorted_v + else: + normalized_item[k] = v + normalized[item_id] = normalized_item + return normalized + + self.assertEqual(normalize_graph(fast_json), normalize_graph(rdflib_json)) + + def test_observation_parsing_edge_cases(self): + """Rigorous data type & properties parsing check to ensure zero property loss.""" + from stats.jsonld_stream_db import _write_observation_shard + + # Custom properties as nested JSON + custom_props = json.dumps({ + "customIntProp": 42, + "dcid:customStrProp": "customVal", + "http://schema.org/url": "https://test-prop.org" + }) + + # Rows with edge-case numbers and strings + chunk = [ + # entity, variable, date, value, provenance, unit, scaling_factor, mmethod, period, props + ("country/ALB", "v1", "2026", "99", "p1", "unit1", "100", "m1", "P1Y", + custom_props), + ("country/USA", "v1", "2026.5", "123.45", "p1", None, "10.5", None, + None, None), + ("country/IND", "v1", "2026-06", "Unavailable", "p1", None, None, None, + None, None), + ] + + ns_map = {"dcid": "https://datacommons.org/ontology/"} + prov_urls = {"p1": "http://my-provenance.org/url"} + + with tempfile.TemporaryDirectory() as temp_dir: + _write_observation_shard((chunk, 0, temp_dir, ns_map, prov_urls)) + + shard_file = os.path.join(temp_dir, "observation-00000.jsonld") + self.assertTrue(os.path.exists(shard_file)) + + with open(shard_file, "r") as f: + data = json.load(f) + + self.assertIn("@graph", data) + graph = data["@graph"] + self.assertEqual(len(graph), 3) + + # 1. Verify first observation (Int types, Custom Props) + obs1 = [ + o for o in graph + if o["dcid:observationAbout"]["@id"] == "dcid:country/ALB" + ][0] + self.assertEqual(obs1["dcid:value"], 99) + self.assertEqual(obs1["dcid:observationDate"], 2026) + self.assertEqual(obs1["dcid:scalingFactor"], 100) + self.assertEqual(obs1["dcid:provenanceUrl"], + "http://my-provenance.org/url") + self.assertEqual(obs1["dcid:observationPeriod"], "P1Y") + + # Verify custom properties from JSON string + self.assertEqual(obs1["dcid:customIntProp"], 42) + self.assertEqual(obs1["dcid:customStrProp"], "customVal") + self.assertEqual(obs1["http://schema.org/url"], "https://test-prop.org") + + # 2. Verify second observation (Float types) + obs2 = [ + o for o in graph + if o["dcid:observationAbout"]["@id"] == "dcid:country/USA" + ][0] + self.assertEqual(obs2["dcid:value"], 123.45) + self.assertEqual(obs2["dcid:observationDate"], 2026.5) + self.assertEqual(obs2["dcid:scalingFactor"], 10.5) + + # 3. Verify third observation (Non-numeric value & Date string) + obs3 = [ + o for o in graph + if o["dcid:observationAbout"]["@id"] == "dcid:country/IND" + ][0] + self.assertEqual(obs3["dcid:value"], "Unavailable") + self.assertEqual(obs3["dcid:observationDate"], "2026-06") + + +if __name__ == "__main__": + unittest.main() diff --git a/simple/tests/stats/runner_test.py b/simple/tests/stats/runner_test.py index 7a863870..1aca65a3 100644 --- a/simple/tests/stats/runner_test.py +++ b/simple/tests/stats/runner_test.py @@ -235,6 +235,57 @@ def test_with_redis_db_cache_schema_update(self): # Redis cache should NOT be cleared in schema update mode. self.assertEqual(1, len(fake_redis.keys("*"))) + def test_dcp_bridge(self): + self.maxDiff = None + with tempfile.TemporaryDirectory() as temp_dir: + input_dir = os.path.join(_INPUT_DIR, "input_dir_driven") + dc_client.get_property_of_entities = mock.MagicMock(return_value={}) + + Runner( + config_file_path=None, + input_dir_path=input_dir, + output_dir_path=temp_dir, + mode=RunMode.DCP_BRIDGE, + ).run() + + # Verify that NO SQLite database file is created + db_path = os.path.join(temp_dir, "datacommons.db") + self.assertFalse(os.path.exists(db_path)) + + # Verify that NO nl directory is created (since GCS/local embeddings are stripped/disabled) + nl_dir = os.path.join(temp_dir, "nl") + self.assertFalse(os.path.exists(nl_dir)) + + # Verify that a jsonld directory is created + jsonld_dir = os.path.join(temp_dir, "jsonld") + self.assertTrue(os.path.exists(jsonld_dir)) + + # Find the subdirectory inside jsonld/ + subdirs = os.listdir(jsonld_dir) + # There should be exactly 1 folder in jsonld/ + self.assertEqual(len(subdirs), 1) + timestamped_dir = os.path.join(jsonld_dir, subdirs[0]) + self.assertTrue(os.path.isdir(timestamped_dir)) + + # Ensure the timestamped directory has files + shard_files = os.listdir(timestamped_dir) + self.assertGreater(len(shard_files), 0) + + # Check that we have both node and observation shard files + node_shards = [f for f in shard_files if f.startswith("node-")] + obs_shards = [f for f in shard_files if f.startswith("observation-")] + + self.assertGreater(len(node_shards), 0) + self.assertGreater(len(obs_shards), 0) + + # Verify that files are valid JSON-LD + for filename in shard_files: + filepath = os.path.join(timestamped_dir, filename) + self.assertTrue(filename.endswith(".jsonld")) + with open(filepath, "r") as f: + data = json.load(f) + self.assertTrue(isinstance(data, (dict, list))) + def test_read_configs_from_subdirs(self): with tempfile.TemporaryDirectory() as temp_dir: # Create subdirectories