Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
import argparse
import asyncio
import logging
import os
import re
import sys

Expand Down Expand Up @@ -294,11 +293,12 @@ async def run_node(port: int, host: str, connect_to: str | None, fund: int, data

# Load existing chain from disk, or start fresh
chain = None
if datadir and os.path.exists(os.path.join(datadir, "data.json")):
if datadir:
try:
from minichain.persistence import load
chain = load(datadir)
logger.info("Restored chain from '%s'", datadir)
from minichain.persistence import load, persistence_exists
if persistence_exists(datadir):
chain = load(datadir)
logger.info("Restored chain from '%s'", datadir)
except FileNotFoundError as e:
logger.warning("Could not load saved chain: %s — starting fresh", e)
except ValueError as e:
Expand Down
224 changes: 139 additions & 85 deletions minichain/persistence.py
Original file line number Diff line number Diff line change
@@ -1,58 +1,56 @@
"""
Chain persistence: save and load the blockchain and state to/from JSON.
Chain persistence: save and load the blockchain and state to/from SQLite.

Design:
- blockchain.json holds the full list of serialised blocks
- state.json holds the accounts dict (includes off-chain credits)
- data.db holds the full chain snapshot, account state, and small metadata.
- legacy data.json snapshots can still be loaded for backward compatibility.

Both files are written atomically (temp → rename) to prevent corruption
on crash. On load, chain integrity is verified before the data is trusted.

Usage:
The public API intentionally stays the same:
from minichain.persistence import save, load

save(blockchain, path="data/")
blockchain = load(path="data/")
"""

from __future__ import annotations

import json
import os
import tempfile
import logging
import copy
import os
import sqlite3
from typing import Any

from .block import Block
from .chain import Blockchain, validate_block_link_and_hash

logger = logging.getLogger(__name__)

_DATA_FILE = "data.json"
_DB_FILE = "data.db"
_LEGACY_DATA_FILE = "data.json"


def persistence_exists(path: str = ".") -> bool:
    """Check whether *path* holds any saved snapshot (SQLite or legacy JSON)."""
    candidates = (_DB_FILE, _LEGACY_DATA_FILE)
    return any(os.path.exists(os.path.join(path, name)) for name in candidates)


# ---------------------------------------------------------------------------
# Public API
# ---------------------------------------------------------------------------

def save(blockchain: Blockchain, path: str = ".") -> None:
"""
Persist the blockchain and account state to a JSON file inside *path*.

Uses atomic write (write-to-temp → rename) with fsync so a crash mid-save
never corrupts the existing file. Chain and state are saved together to
prevent torn snapshots.
"""
def save(blockchain: Blockchain, path: str = ".") -> None:
"""Persist the blockchain and account state to SQLite inside *path*."""
os.makedirs(path, exist_ok=True)
db_path = os.path.join(path, _DB_FILE)

with blockchain._lock: # Thread-safe: hold lock while serialising
with blockchain._lock:
chain_data = [block.to_dict() for block in blockchain.chain]
state_data = copy.deepcopy(blockchain.state.accounts)

snapshot = {
"chain": chain_data,
"state": state_data
}
state_data = json.loads(json.dumps(blockchain.state.accounts))

_atomic_write_json(os.path.join(path, _DATA_FILE), snapshot)
_save_snapshot_to_sqlite(db_path, {"chain": chain_data, "state": state_data})
Comment on lines +49 to +53
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🧹 Nitpick | 🔵 Trivial

Consider using copy.deepcopy() instead of JSON round-trip.

The json.loads(json.dumps(...)) pattern works but is less efficient than copy.deepcopy() for in-memory deep copying. However, if the intent is to validate JSON serializability at save time, this is acceptable.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@minichain/persistence.py` around lines 49 - 53, The current snapshot uses
json.loads(json.dumps(blockchain.state.accounts)) to deep-copy state which is
less efficient; replace the JSON round-trip with copy.deepcopy to produce an
in-memory deep copy of blockchain.state.accounts while holding blockchain._lock
before calling _save_snapshot_to_sqlite(db_path, {"chain": chain_data, "state":
state_data}); update the import to include copy and use
copy.deepcopy(blockchain.state.accounts) and leave chain_data creation via
[block.to_dict() for block in blockchain.chain] as-is (unless you also want to
validate JSON serializability, in which case keep the round-trip).


logger.info(
"Saved %d blocks and %d accounts to '%s'",
Expand All @@ -63,42 +61,33 @@ def save(blockchain: Blockchain, path: str = ".") -> None:


def load(path: str = ".") -> Blockchain:
"""
Restore a Blockchain from the JSON file inside *path*.
"""Restore a Blockchain from SQLite inside *path* (with legacy JSON fallback)."""
db_path = os.path.join(path, _DB_FILE)
legacy_path = os.path.join(path, _LEGACY_DATA_FILE)

Steps:
1. Load and deserialise blocks from data.json
2. Verify chain integrity (genesis, linkage, hashes)
3. Load account state

Raises:
FileNotFoundError: if data.json is missing.
ValueError: if data is invalid or integrity checks fail.
"""
data_path = os.path.join(path, _DATA_FILE)
snapshot = _read_json(data_path)
if os.path.exists(db_path):
snapshot = _load_snapshot_from_sqlite(db_path)
elif os.path.exists(legacy_path):
snapshot = _read_legacy_json(legacy_path)
else:
raise FileNotFoundError(f"Persistence file not found in '{path}'")

if not isinstance(snapshot, dict):
raise ValueError(f"Invalid snapshot data in '{data_path}'")
raise ValueError(f"Invalid snapshot data in '{path}'")

raw_blocks = snapshot.get("chain")
raw_accounts = snapshot.get("state")

if not isinstance(raw_blocks, list) or not raw_blocks:
raise ValueError(f"Invalid or empty chain data in '{data_path}'")
raise ValueError(f"Invalid or empty chain data in '{path}'")
if not isinstance(raw_accounts, dict):
raise ValueError(f"Invalid accounts data in '{data_path}'")
raise ValueError(f"Invalid accounts data in '{path}'")

blocks = [_deserialize_block(b) for b in raw_blocks]

# --- Integrity verification ---
_verify_chain_integrity(blocks)

# --- Rebuild blockchain properly (no __new__ hack) ---
blockchain = Blockchain() # creates genesis + fresh state
blockchain.chain = blocks # replace with loaded chain

# Restore state
blockchain = Blockchain()
blockchain.chain = blocks
blockchain.state.accounts = raw_accounts
Comment thread
coderabbitai[bot] marked this conversation as resolved.
Outdated
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

Line 91 trusts an unauthenticated accounts snapshot.

After _verify_chain_integrity(blocks), Line 91 assigns raw_accounts wholesale. Editing one accounts.account_json row can rewrite balances/nonces on startup without changing any block, and the chain cannot prove that state because mining rewards are applied outside block payloads in main.py, Lines 87-90 and 149-151. Please persist enough information to rebuild state on load, or add a verifiable state root/signature before accepting accounts.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@minichain/persistence.py` around lines 89 - 91, The code assigns raw_accounts
directly into blockchain.state.accounts after calling
_verify_chain_integrity(blocks), which allows tampered account snapshots to be
trusted; instead either rebuild the state from the verified blocks or verify a
signed/state-root against the snapshot: update the load path that currently does
blockchain = Blockchain(); blockchain.chain = blocks; blockchain.state.accounts
= raw_accounts to (a) reconstruct state by replaying blocks/transactions using
the existing state application logic (e.g., the same functions that apply
transactions/mining rewards used in main.py) and then replace
blockchain.state.accounts with the reconstructed state, or (b) persist and load
a verifiable state root/signature alongside the snapshot and check that the
computed state root from the reconstructed state matches before accepting
raw_accounts (use the _verify_chain_integrity and a new compute_state_root or
signature verification routine to compare). Ensure references to Blockchain,
blockchain.state.accounts, raw_accounts, and _verify_chain_integrity are used so
the change replaces the unsafe direct assignment with state reconstruction or
state-root verification.


logger.info(
Expand All @@ -114,14 +103,13 @@ def load(path: str = ".") -> Blockchain:
# Integrity verification
# ---------------------------------------------------------------------------

def _verify_chain_integrity(blocks: list) -> None:

def _verify_chain_integrity(blocks: list[Block]) -> None:
"""Verify genesis, hash linkage, and block hashes."""
# Check genesis
genesis = blocks[0]
if genesis.index != 0 or genesis.hash != "0" * 64:
raise ValueError("Invalid genesis block")

# Check linkage and hashes for every subsequent block
for i in range(1, len(blocks)):
block = blocks[i]
prev = blocks[i - 1]
Expand All @@ -132,46 +120,112 @@ def _verify_chain_integrity(blocks: list) -> None:


# ---------------------------------------------------------------------------
# Helpers
# SQLite helpers
# ---------------------------------------------------------------------------

def _atomic_write_json(filepath: str, data) -> None:
"""Write JSON atomically with fsync for durability."""
dir_name = os.path.dirname(filepath) or "."
fd, tmp_path = tempfile.mkstemp(dir=dir_name, suffix=".tmp")

def _connect(db_path: str) -> sqlite3.Connection:
conn = sqlite3.connect(db_path)
conn.row_factory = sqlite3.Row
conn.execute("PRAGMA foreign_keys = ON")
return conn


def _initialize_schema(conn: sqlite3.Connection) -> None:
conn.executescript(
"""
CREATE TABLE IF NOT EXISTS blocks (
height INTEGER PRIMARY KEY,
block_json TEXT NOT NULL
);

CREATE TABLE IF NOT EXISTS accounts (
address TEXT PRIMARY KEY,
account_json TEXT NOT NULL
);

CREATE TABLE IF NOT EXISTS metadata (
key TEXT PRIMARY KEY,
value TEXT NOT NULL
);
"""
)


def _save_snapshot_to_sqlite(db_path: str, snapshot: dict[str, Any]) -> None:
    """Write *snapshot* into the SQLite file at *db_path*.

    *snapshot* carries "chain" (list of block dicts, each with an "index")
    and "state" (address -> account dict). The old contents are cleared and
    the new rows written inside a single transaction.
    """
    conn = _connect(db_path)
    try:
        _initialize_schema(conn)

        # Serialise outside the transaction; these are pure computations.
        block_rows = [
            (int(block["index"]), json.dumps(block, sort_keys=True))
            for block in snapshot["chain"]
        ]
        account_rows = [
            (address, json.dumps(account, sort_keys=True))
            for address, account in sorted(snapshot["state"].items())
        ]

        with conn:  # one transaction: clear-and-rewrite is atomic
            conn.execute("DELETE FROM blocks")
            conn.execute("DELETE FROM accounts")
            conn.execute("DELETE FROM metadata")
            conn.executemany(
                "INSERT INTO blocks (height, block_json) VALUES (?, ?)",
                block_rows,
            )
            conn.executemany(
                "INSERT INTO accounts (address, account_json) VALUES (?, ?)",
                account_rows,
            )
            conn.execute(
                "INSERT INTO metadata (key, value) VALUES (?, ?)",
                ("chain_length", str(len(block_rows))),
            )
    finally:
        conn.close()


def _load_snapshot_from_sqlite(db_path: str) -> dict[str, Any]:
    """Read a {"chain": [...], "state": {...}} snapshot from *db_path*.

    Raises:
        ValueError: if the file is not a readable SQLite database or a
            stored JSON payload fails to parse.
    """
    try:
        connection = _connect(db_path)
    except sqlite3.DatabaseError as exc:
        raise ValueError(f"Invalid SQLite persistence data in '{db_path}'") from exc

    try:
        _initialize_schema(connection)
        raw_blocks = connection.execute(
            "SELECT block_json FROM blocks ORDER BY height ASC"
        ).fetchall()
        raw_accounts = connection.execute(
            "SELECT address, account_json FROM accounts ORDER BY address ASC"
        ).fetchall()
    except sqlite3.DatabaseError as exc:
        raise ValueError(f"Invalid SQLite persistence data in '{db_path}'") from exc
    finally:
        connection.close()

    # Decode after the connection is closed; a bad payload is a data error.
    try:
        chain = [json.loads(row["block_json"]) for row in raw_blocks]
        state = {
            row["address"]: json.loads(row["account_json"])
            for row in raw_accounts
        }
    except json.JSONDecodeError as exc:
        raise ValueError(f"Invalid persisted JSON payload in '{db_path}'") from exc

    return {"chain": chain, "state": state}
Comment thread
coderabbitai[bot] marked this conversation as resolved.

def _read_json(filepath: str):
if not os.path.exists(filepath):
raise FileNotFoundError(f"Persistence file not found: '{filepath}'")

# ---------------------------------------------------------------------------
# Legacy JSON helpers
# ---------------------------------------------------------------------------


def _read_legacy_json(filepath: str) -> dict[str, Any]:
with open(filepath, "r", encoding="utf-8") as f:
return json.load(f)
Comment on lines +257 to 259
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🧹 Nitpick | 🔵 Trivial

Minor: Remove redundant mode argument.

The "r" mode is the default for open(). This is a nitpick from static analysis.

♻️ Suggested fix
 def _read_legacy_json(filepath: str) -> dict[str, Any]:
-    with open(filepath, "r", encoding="utf-8") as f:
+    with open(filepath, encoding="utf-8") as f:
         return json.load(f)
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
def _read_legacy_json(filepath: str) -> dict[str, Any]:
with open(filepath, "r", encoding="utf-8") as f:
return json.load(f)
def _read_legacy_json(filepath: str) -> dict[str, Any]:
with open(filepath, encoding="utf-8") as f:
return json.load(f)
🧰 Tools
🪛 Ruff (0.15.7)

[warning] 258-258: Unnecessary mode argument

Remove mode argument

(UP015)

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@minichain/persistence.py` around lines 257 - 259, The open call in
_read_legacy_json unnecessarily passes the default mode "r"; remove the explicit
"r" argument so the function uses open(filepath, encoding="utf-8") when reading
JSON, keeping json.load(f) unchanged.



def _deserialize_block(data: dict) -> Block:
"""Reconstruct a Block (including its transactions) from a plain dict."""
# ---------------------------------------------------------------------------
# Block deserialisation
# ---------------------------------------------------------------------------


def _deserialize_block(data: dict[str, Any]) -> Block:
    """Reconstruct a Block from its serialised dict form via Block.from_dict."""
    return Block.from_dict(data)
Loading
Loading