| |
|
|
| import time |
| import uuid |
| import hashlib |
| import json |
| from abc import ABC, abstractmethod |
| from dataclasses import dataclass |
| from enum import Enum |
| from typing import Any, Dict, List, Optional |
| import logging |
|
|
# Configure root logging once at import time; module-scoped logger below.
# NOTE(review): basicConfig at import affects the whole process -- confirm
# this module is not consumed as a library where the host app configures logging.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
|
|
|
|
| |
# How long a queued action remains valid before it is dropped as EXPIRED.
DEFAULT_TTL_HOURS = 24
SECONDS_PER_HOUR = 3600
# Cap on sync retries before an action is permanently marked FAILED.
MAX_RETRY_ATTEMPTS = 5
# First retry waits this long; each subsequent retry doubles it (exponential backoff).
BASE_RETRY_DELAY_SECONDS = 1
|
|
|
|
class NetworkState(Enum):
    """Device connectivity state; attempt_sync aborts only when OFFLINE."""
    ONLINE = "ONLINE"
    OFFLINE = "OFFLINE"
    HIGH_LATENCY = "HIGH_LATENCY"
|
|
|
|
class ActionStatus(Enum):
    """Lifecycle of a queued action as persisted in the offline_queue table."""
    PENDING = "PENDING"            # queued locally, not yet pushed
    SYNCED = "SYNCED"              # accepted by the backend
    EXPIRED = "EXPIRED"            # TTL elapsed before a successful sync
    FAILED = "FAILED"              # retries exhausted
    RETRY_SCHEDULED = "RETRY_SCHEDULED"  # failed at least once, will be retried
|
|
|
|
class SyncResult(Enum):
    """Overall outcome of one attempt_sync() pass over the queue."""
    SUCCESS = "SUCCESS"                  # no failures (possibly nothing to do)
    ABORT_NO_NET = "ABORT_NO_NET"        # skipped entirely: network offline
    PARTIAL_SUCCESS = "PARTIAL_SUCCESS"  # some actions synced, some failed
    ALL_FAILED = "ALL_FAILED"            # every attempted action failed
|
|
|
|
@dataclass
class SyncSummary:
    """Outcome of a sync pass: overall result plus per-action counters.

    Attributes:
        result: Aggregate outcome of the pass.
        synced_count: Actions accepted by the backend.
        failed_count: Actions that failed to sync this pass.
        expired_count: Actions dropped because their TTL elapsed.
        error_messages: Human-readable failure descriptions (never None
            after construction).
    """
    result: SyncResult
    synced_count: int = 0
    failed_count: int = 0
    expired_count: int = 0
    # Optional: the default must be None (a mutable [] class-level default
    # would be shared across instances); __post_init__ normalizes it.
    error_messages: Optional[List[str]] = None

    def __post_init__(self):
        # Normalize the None sentinel (default or explicitly passed) to a
        # fresh list so every instance owns its own error list.
        if self.error_messages is None:
            self.error_messages = []
|
|
|
|
class LocalDatabase(ABC):
    """Minimal persistence interface the sync queue requires from local storage."""

    @abstractmethod
    def insert(self, table: str, data: dict) -> None:
        """Persist *data* as a new row in *table*."""
        pass

    @abstractmethod
    def query(self, sql: str) -> List[dict]:
        """Run *sql* and return matching rows as dicts (empty list if none)."""
        pass

    @abstractmethod
    def update(self, record_id: str, **kwargs) -> None:
        """Update the row identified by *record_id* with the given field values."""
        pass

    @abstractmethod
    def get_state_snapshot(self) -> dict:
        """Get current local state for conflict detection."""
        # NOTE(review): snapshot must be JSON-serializable -- SyncQueue hashes
        # json.dumps(snapshot, sort_keys=True).
        pass
|
|
|
|
class ApiClient(ABC):
    """Transport interface for pushing actions to the backend."""

    @abstractmethod
    def post(self, endpoint: str, json: dict) -> dict:
        """POST *json* to *endpoint* and return the parsed response.

        Implementations are expected to raise TimeoutError / ConnectionError
        on transport failures -- SyncQueue catches those for retry handling.
        """
        pass
|
|
|
|
class SyncQueue:
    """Manages offline action queue with sync capabilities.

    Actions are persisted through a LocalDatabase while the device is offline
    and pushed to the backend via ApiClient when connectivity allows.  Each
    action carries a TTL and a hash of the local state at enqueue time so the
    server can detect conflicts against stale data.
    """

    def __init__(
        self,
        local_db: LocalDatabase,
        api_client: ApiClient,
        ttl_hours: int = DEFAULT_TTL_HOURS
    ):
        self._db = local_db
        self._api_client = api_client
        self._ttl_hours = ttl_hours
        # In-memory retry bookkeeping, keyed by action id.
        # NOTE(review): this is lost on process restart even though each DB
        # row also stores retry_count -- consider seeding from the persisted
        # value when loading pending actions.
        self._retry_counts: Dict[str, int] = {}

    def get_local_state_hash(self) -> str:
        """Generate a hash of the current local state for conflict detection.

        Returns:
            SHA-256 hex digest of the canonical (sorted-key) JSON form of the
            DB snapshot, or "" if the snapshot could not be produced
            (best-effort: failures are logged, never raised).
        """
        try:
            state_snapshot = self._db.get_state_snapshot()
            state_json = json.dumps(state_snapshot, sort_keys=True)
            return hashlib.sha256(state_json.encode('utf-8')).hexdigest()
        except Exception:
            logger.exception("Failed to generate local state hash")
            return ""

    def enqueue_action(
        self,
        action_type: str,
        payload: dict,
        priority: int = 1
    ) -> str:
        """
        Stores an action locally with a Time-To-Live (TTL).
        If the action acts on old data, we tag it for conflict resolution.

        Args:
            action_type: Non-empty label identifying the kind of action.
            payload: Action data; sensitive top-level keys are stripped.
            priority: Higher values sync first (default 1).

        Returns:
            The action ID for tracking.

        Raises:
            ValueError: If action_type or payload is invalid.
            Exception: Propagates any database insert failure.
        """
        if not action_type or not isinstance(action_type, str):
            raise ValueError("action_type must be a non-empty string")

        if not isinstance(payload, dict):
            raise ValueError("payload must be a dictionary")

        # Capture the clock once so timestamp and ttl_expiry are derived from
        # the same instant (two separate time.time() calls could drift).
        now = time.time()
        action_id = str(uuid.uuid4())
        action_item = {
            "id": action_id,
            "timestamp": now,
            "type": action_type,
            "payload": self._sanitize_payload(payload),
            "priority": priority,
            "status": ActionStatus.PENDING.value,
            "ttl_expiry": now + (self._ttl_hours * SECONDS_PER_HOUR),
            "device_state_hash": self.get_local_state_hash(),
            "retry_count": 0
        }

        try:
            self._db.insert("offline_queue", action_item)
            logger.info("Action %s queued locally", action_id)
            return action_id
        except Exception:
            # logger.exception already attaches the traceback and message.
            logger.exception("Failed to enqueue action")
            raise

    def _sanitize_payload(self, payload: dict) -> dict:
        """Remove sensitive fields before storage/transmission.

        Only top-level keys are filtered (case-insensitively); values nested
        inside sub-dicts are passed through unchanged.
        """
        sensitive_keys = {'password', 'token',
                          'secret', 'api_key', 'credential'}
        return {
            k: v for k, v in payload.items()
            if k.lower() not in sensitive_keys
        }

    def attempt_sync(self, current_network_state: NetworkState) -> SyncSummary:
        """
        Only attempts sync if network is stable.

        Pending and retry-scheduled actions are processed highest-priority
        first (FIFO within equal priority); expired actions are dropped.

        Returns:
            SyncSummary with results of the sync operation.
        """
        if current_network_state == NetworkState.OFFLINE:
            logger.info("Sync aborted: network offline")
            return SyncSummary(result=SyncResult.ABORT_NO_NET)

        try:
            pending_actions = self._db.query(
                "SELECT * FROM offline_queue WHERE status='PENDING' OR status='RETRY_SCHEDULED'"
            )
        except Exception as e:
            logger.exception("Failed to query pending actions")
            return SyncSummary(
                result=SyncResult.ALL_FAILED,
                error_messages=[f"Database query failed: {str(e)}"]
            )

        if not pending_actions:
            logger.info("No pending actions to sync")
            return SyncSummary(result=SyncResult.SUCCESS)

        # Highest priority first; FIFO (oldest timestamp) within a priority.
        pending_actions.sort(
            key=lambda x: (-x.get('priority', 1), x['timestamp']))

        synced_count = 0
        failed_count = 0
        expired_count = 0
        error_messages: List[str] = []

        for action in pending_actions:
            action_id = action['id']

            # Drop stale actions rather than syncing data past its TTL.
            if time.time() > action['ttl_expiry']:
                self._db.update(action_id, status=ActionStatus.EXPIRED.value)
                expired_count += 1
                logger.warning("Action %s expired", action_id)
                continue

            if self._sync_action(action):
                synced_count += 1
            else:
                failed_count += 1
                error_messages.append(f"Action {action_id} failed to sync")

        # Note: a pass where every action merely expired counts as SUCCESS.
        if failed_count == 0:
            result = SyncResult.SUCCESS
        elif synced_count == 0:
            result = SyncResult.ALL_FAILED
        else:
            result = SyncResult.PARTIAL_SUCCESS

        return SyncSummary(
            result=result,
            synced_count=synced_count,
            failed_count=failed_count,
            expired_count=expired_count,
            error_messages=error_messages
        )

    def _sync_action(self, action: dict) -> bool:
        """Attempt to sync a single action to the cloud.

        Returns True on success.  On failure, schedules a retry (or marks the
        action FAILED once retries are exhausted) and returns False.
        """
        action_id = action['id']

        sync_payload = {
            "id": action_id,
            "type": action['type'],
            "payload": action['payload'],
            "timestamp": action['timestamp'],
            "device_state_hash": action.get('device_state_hash', '')
        }

        try:
            self._api_client.post("/sync", json=sync_payload)
            self._db.update(action_id, status=ActionStatus.SYNCED.value)
            logger.info("Action %s synced successfully", action_id)
            return True

        except TimeoutError:
            logger.error("Timeout syncing %s", action_id)
            self._schedule_retry(action_id)
            return False

        except ConnectionError as e:
            logger.error("Connection error syncing %s: %s", action_id, e)
            self._schedule_retry(action_id)
            return False

        except Exception:
            logger.exception("Unexpected error syncing %s", action_id)
            # _schedule_retry already enforces MAX_RETRY_ATTEMPTS and marks
            # the action FAILED when exhausted, so no separate cap check is
            # duplicated here.
            self._schedule_retry(action_id)
            return False

    def _schedule_retry(self, action_id: str) -> None:
        """Schedule an action for retry with exponential backoff.

        Marks the action FAILED once MAX_RETRY_ATTEMPTS is reached; otherwise
        bumps the retry counter (in memory and on the DB row) and flags the
        row RETRY_SCHEDULED.  NOTE(review): the computed backoff delay is only
        logged -- nothing here sleeps or timers the retry; presumably the
        caller of attempt_sync honors the schedule.  Confirm with the caller.
        """
        current_count = self._retry_counts.get(action_id, 0)

        if current_count >= MAX_RETRY_ATTEMPTS:
            self._db.update(action_id, status=ActionStatus.FAILED.value)
            logger.error(
                "Action %s exceeded max retries (%d)",
                action_id, MAX_RETRY_ATTEMPTS
            )
            return

        self._retry_counts[action_id] = current_count + 1
        delay = BASE_RETRY_DELAY_SECONDS * (2 ** current_count)

        self._db.update(
            action_id,
            status=ActionStatus.RETRY_SCHEDULED.value,
            retry_count=current_count + 1
        )

        logger.info(
            "Action %s scheduled for retry #%d in %d seconds",
            action_id, current_count + 1, delay
        )
|
|