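"""Dataset synchronization service.

Tracks local dataset files (per-language Parquet files, verified audio
recordings, and stats.json) via content hashes and uploads new or modified
files to a Hugging Face dataset repository, either from the APScheduler jobs
created by init_scheduler() or when this module is run directly.
"""
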
import os
import json
import hashlib
import logging
from pathlib import Path
from datetime import datetime, timedelta
import pandas as pd
from huggingface_hub import HfApi
from apscheduler.schedulers.background import BackgroundScheduler
from apscheduler.triggers.cron import CronTrigger
import pytz
from dotenv import load_dotenv
import soundfile as sf
import numpy as np
from concurrent.futures import ThreadPoolExecutor, as_completed
import time
from sqlalchemy import text
from database_manager import engine
import filelock
import queue
from typing import Dict, Set, List, Optional, Iterator
import tempfile
import shutil
import gc
import socket
from contextlib import contextmanager

logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler('dataset_sync.log'),
        logging.StreamHandler()
    ]
)
logger = logging.getLogger(__name__)

load_dotenv()

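# Configuration is read from the environment (typically a .env file loaded by
# load_dotenv() above). Variables used in this module and their defaults:
#   DATASET_BASE_DIR      local dataset root              (default: /app/datasets)
#   HF_TOKEN              Hugging Face API token
#   HF_REPO_ID            target dataset repository id
#   MAX_UPLOAD_WORKERS    parallel upload threads         (default: 4)
#   MAX_UPLOAD_RETRIES    attempts per file               (default: 3)
#   UPLOAD_CHUNK_SIZE     hashing chunk size in bytes     (default: 1 MiB)
#   SYNC_MEMORY_LIMIT_MB  memory warning threshold        (default: 1024)
#   NETWORK_TIMEOUT       socket timeout in seconds       (default: 30)
#   UPLOAD_BATCH_SIZE     files per upload batch          (default: 10)
#   SYNC_HOUR / SYNC_MINUTE / SYNC_TIMEZONE  daily sync schedule (default: 00:00 UTC)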
BASE_DIR = Path(os.getenv('DATASET_BASE_DIR', '/app/datasets'))
SYNC_STATE_FILE = BASE_DIR / '.sync_state.json'
STATS_FILE = BASE_DIR / 'stats.json'


class DatasetSynchronizer:
    def __init__(self):
        self.base_dir = BASE_DIR
        self.sync_state_file = SYNC_STATE_FILE
        self.hf_token = os.getenv('HF_TOKEN')
        self.repo_id = os.getenv('HF_REPO_ID')
        self.hf_api = HfApi(token=self.hf_token)
        self.sync_state = self._load_sync_state()
        self.stats_file = STATS_FILE
        self.max_workers = int(os.getenv('MAX_UPLOAD_WORKERS', 4))
        self.max_retries = int(os.getenv('MAX_UPLOAD_RETRIES', 3))
        self.verified_files = set()
        self.verified_cache = {}
        self.lock_file = BASE_DIR / '.sync.lock'
        self.lock = filelock.FileLock(str(self.lock_file), timeout=0)
        self.file_queue = queue.Queue()
        self.uploaded_files: Set[str] = set()
        self.failed_files: Dict[str, int] = {}
        self.chunk_size = int(os.getenv('UPLOAD_CHUNK_SIZE', 1024 * 1024))
        self.memory_limit = int(os.getenv('SYNC_MEMORY_LIMIT_MB', 1024)) * 1024 * 1024
        self.network_timeout = int(os.getenv('NETWORK_TIMEOUT', 30))
        self.batch_size = int(os.getenv('UPLOAD_BATCH_SIZE', 10))
        self.recovery_file = BASE_DIR / '.sync_recovery'

    def _load_sync_state(self):
        """Load the sync state from file, or create a new one if none exists"""
        if self.sync_state_file.exists():
            with open(self.sync_state_file, 'r', encoding='utf-8') as f:
                return json.load(f)
        return {
            'files': {},
            'last_sync': None,
            'sync_count': 0
        }

    def _save_sync_state(self):
        """Save the current sync state to file"""
        with open(self.sync_state_file, 'w', encoding='utf-8') as f:
            json.dump(self.sync_state, f, indent=2, ensure_ascii=False)

    def _calculate_file_hash(self, file_path: str) -> str:
        """Hash a file in chunks so large files are never read into memory at once"""
        sha256_hash = hashlib.sha256()
        with open(file_path, "rb") as f:
            for byte_block in iter(lambda: f.read(self.chunk_size), b""):
                sha256_hash.update(byte_block)
        return sha256_hash.hexdigest()

    def _prepare_parquet_files(self):
        """Prepare Parquet files before synchronization"""
        try:
            from prepare_parquet import update_parquet_files
            update_parquet_files()
            logger.info("Parquet files updated successfully")
            return True
        except Exception as e:
            logger.error(f"Error preparing Parquet files: {str(e)}")
            return False

    def _is_verified_audio(self, file_path):
        """Check if an audio file belongs to a verified recording"""
        try:
            # Expected layout: <base_dir>/<lang_code>/audio/<audio_filename>
            parts = file_path.parts
            lang_code = parts[-3]
            audio_filename = parts[-1]

            cache_key = f"{lang_code}:{audio_filename}"
            if cache_key in self.verified_cache:
                return self.verified_cache[cache_key]

            with engine.connect() as conn:
                table_exists = conn.execute(text("""
                    SELECT EXISTS (
                        SELECT FROM information_schema.tables
                        WHERE table_name = :table_name
                    )
                """), {"table_name": f"recordings_{lang_code}"}).scalar()

                if not table_exists:
                    logger.debug(f"Table recordings_{lang_code} does not exist, skipping verification check")
                    self.verified_cache[cache_key] = False
                    return False

                # The table name comes from a local directory name, so only known
                # language codes ever reach this f-string interpolation.
                result = conn.execute(text(f"""
                    SELECT EXISTS (
                        SELECT 1
                        FROM recordings_{lang_code}
                        WHERE audio_filename = :filename
                        AND status = 'verified'
                    )
                """), {"filename": audio_filename}).scalar()

                self.verified_cache[cache_key] = bool(result)
                return self.verified_cache[cache_key]

        except Exception as e:
            logger.error(f"Error checking verification status for {file_path}: {str(e)}")
            return False

    def _get_modified_files(self):
        """Get the list of new or modified files since the last sync"""
        if not self._prepare_parquet_files():
            logger.error("Failed to prepare Parquet files, aborting sync")
            return []

        modified_files = []
        skipped_langs = set()

        # Global stats file
        if self.stats_file.exists():
            current_hash = self._calculate_file_hash(self.stats_file)
            stored_hash = self.sync_state['files'].get(str(self.stats_file))

            if current_hash != stored_hash:
                modified_files.append(self.stats_file)
                self.sync_state['files'][str(self.stats_file)] = current_hash

        # Per-language directories
        for lang_dir in self.base_dir.iterdir():
            if not lang_dir.is_dir() or lang_dir.name.startswith('.'):
                continue

            lang_code = lang_dir.name

            # Skip languages that don't have a recordings table yet
            with engine.connect() as conn:
                table_exists = conn.execute(text("""
                    SELECT EXISTS (
                        SELECT FROM information_schema.tables
                        WHERE table_name = :table_name
                    )
                """), {"table_name": f"recordings_{lang_code}"}).scalar()

            if not table_exists:
                skipped_langs.add(lang_code)
                logger.info(f"Skipping language {lang_code} - no recordings table exists")
                continue

            # Per-language Parquet file
            parquet_file = lang_dir / f"{lang_dir.name}.parquet"
            if parquet_file.exists():
                current_hash = self._calculate_file_hash(parquet_file)
                stored_hash = self.sync_state['files'].get(str(parquet_file))

                if current_hash != stored_hash:
                    modified_files.append(parquet_file)
                    self.sync_state['files'][str(parquet_file)] = current_hash

            # Verified audio recordings
            audio_dir = lang_dir / 'audio'
            if audio_dir.exists():
                for audio_file in audio_dir.glob('[a-z0-9]*_[0-9]*_[0-9]*.wav'):
                    if not self._is_verified_audio(audio_file):
                        logger.debug(f"Skipping unverified audio: {audio_file}")
                        continue

                    current_hash = self._calculate_file_hash(audio_file)
                    stored_hash = self.sync_state['files'].get(str(audio_file))

                    if current_hash != stored_hash:
                        modified_files.append(audio_file)
                        self.sync_state['files'][str(audio_file)] = current_hash
                        logger.info(f"Added verified audio file to sync: {audio_file}")

        if skipped_langs:
            logger.info(f"Skipped languages due to missing tables: {', '.join(skipped_langs)}")

        return modified_files

    @contextmanager
    def _memory_tracker(self):
        """Track memory usage during an operation and warn if growth exceeds the limit"""
        try:
            gc.collect()
            start_mem = self._get_memory_usage()
            yield
        finally:
            gc.collect()
            end_mem = self._get_memory_usage()
            if end_mem - start_mem > self.memory_limit:
                logger.warning(f"Memory usage exceeded limit: {(end_mem - start_mem) / 1024 / 1024:.2f}MB")

    def _get_memory_usage(self) -> int:
        """Get current RSS memory usage in bytes (imports psutil lazily)"""
        import psutil
        process = psutil.Process()
        return process.memory_info().rss

    def _batch_files(self, files: list) -> Iterator[list]:
        """Yield files in fixed-size batches to keep per-round memory bounded"""
        for i in range(0, len(files), self.batch_size):
            yield files[i:i + self.batch_size]

    def _save_recovery_state(self, failed_files: dict):
        """Save failed uploads for recovery"""
        try:
            with open(self.recovery_file, 'w') as f:
                json.dump(failed_files, f)
        except Exception as e:
            logger.error(f"Failed to save recovery state: {e}")

    def _load_recovery_state(self) -> dict:
        """Load failed uploads from previous sync"""
        try:
            if self.recovery_file.exists():
                with open(self.recovery_file) as f:
                    return json.load(f)
        except Exception as e:
            logger.error(f"Failed to load recovery state: {e}")
        return {}

    def _upload_file_with_retry(self, file_path: str, retry_count: int = 0) -> bool:
        """Upload a single file with a network timeout and exponential-backoff retries"""
        socket.setdefaulttimeout(self.network_timeout)
        temp_path = None

        try:
            with self._memory_tracker():
                relative_path = str(Path(file_path).relative_to(self.base_dir))
                relative_path = relative_path.replace('\\', '/')

                # Upload from a temporary copy so the source file can't change mid-upload
                with tempfile.NamedTemporaryFile(delete=False) as temp_file:
                    temp_path = temp_file.name
                shutil.copy2(file_path, temp_path)

                logger.info(f"Uploading {relative_path} to {self.repo_id} (attempt {retry_count + 1})")

                self.hf_api.upload_file(
                    path_or_fileobj=temp_path,
                    path_in_repo=relative_path,
                    repo_id=self.repo_id,
                    repo_type="dataset"
                )

                self.uploaded_files.add(file_path)
                logger.info(f"Successfully uploaded {relative_path}")
                return True

        except socket.timeout:
            logger.error(f"Network timeout uploading {file_path}")
            if retry_count < self.max_retries - 1:
                time.sleep(2 ** retry_count)
                return self._upload_file_with_retry(file_path, retry_count + 1)
            self.failed_files[file_path] = retry_count + 1
            return False

        except Exception as e:
            if retry_count < self.max_retries - 1:
                logger.warning(f"Upload failed for {file_path}, retrying... ({retry_count + 1}/{self.max_retries})")
                time.sleep(2 ** retry_count)
                return self._upload_file_with_retry(file_path, retry_count + 1)
            logger.error(f"Error uploading {file_path} after {self.max_retries} attempts: {str(e)}")
            self.failed_files[file_path] = retry_count + 1
            return False

        finally:
            # Always remove the temporary copy and restore the default socket timeout
            if temp_path and os.path.exists(temp_path):
                os.unlink(temp_path)
            socket.setdefaulttimeout(None)

    def _parallel_upload(self, files: List[Path]) -> bool:
        """Upload multiple files in parallel; returns False if any upload fails"""
        successful = True
        with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
            future_to_file = {
                executor.submit(self._upload_file_with_retry, str(file_path)): file_path
                for file_path in files
            }

            for future in as_completed(future_to_file):
                file_path = future_to_file[future]
                try:
                    if not future.result():
                        successful = False
                        logger.error(f"Failed to upload: {file_path}")
                except Exception as e:
                    successful = False
                    logger.error(f"Unexpected error uploading {file_path}: {str(e)}")

        return successful

    def is_syncing(self):
        """Check if a sync is in progress"""
        try:
            # timeout=0 makes acquisition fail immediately if another sync holds the lock
            with self.lock:
                return False
        except filelock.Timeout:
            return True

    def sync_dataset(self) -> bool:
        """Run a full sync, retrying any uploads that failed in the previous run"""
        try:
            with self.lock:
                # Re-queue uploads that failed during the previous sync
                recovery_state = self._load_recovery_state()
                if recovery_state:
                    logger.info(f"Found {len(recovery_state)} failed uploads from previous sync")

                modified_files = self._get_modified_files()
                modified_files.extend(Path(p) for p in recovery_state.keys())

                if not modified_files:
                    return True

                # Upload in batches to limit concurrent memory use
                for batch in self._batch_files(modified_files):
                    with self._memory_tracker():
                        if not self._parallel_upload(batch):
                            self._save_recovery_state(self.failed_files)
                            return False

                self._update_sync_state()
                if self.recovery_file.exists():
                    self.recovery_file.unlink()

                return len(self.failed_files) == 0

        except filelock.Timeout:
            logger.info("Another sync is already in progress, skipping this run")
            return False

        except Exception as e:
            logger.error(f"Sync error: {e}")
            self._save_recovery_state(self.failed_files)
            return False

        finally:
            self.uploaded_files.clear()
            self.failed_files.clear()

    def _update_sync_state(self):
        """Record hashes for uploaded files and write the state file atomically"""
        try:
            for file_path in self.uploaded_files:
                self.sync_state['files'][str(file_path)] = self._calculate_file_hash(str(file_path))

            self.sync_state['last_sync'] = datetime.now().isoformat()
            self.sync_state['sync_count'] += 1

            # Write to a temporary file first, then atomically replace the old state file
            temp_state_file = str(self.sync_state_file) + '.tmp'
            with open(temp_state_file, 'w') as f:
                json.dump(self.sync_state, f)
            os.replace(temp_state_file, self.sync_state_file)

        except Exception as e:
            logger.error(f"Failed to update sync state: {e}")
            raise


def sync_job():
    """Function to be called by the scheduler"""
    try:
        synchronizer = DatasetSynchronizer()
        synchronizer.sync_dataset()
    except Exception as e:
        logger.error(f"Error in sync job: {str(e)}")


def init_scheduler():
    """Initialize the scheduler with a daily sync and a one-time sync shortly after startup"""
    scheduler = BackgroundScheduler()

    sync_hour = int(os.getenv('SYNC_HOUR', '0'))
    sync_minute = int(os.getenv('SYNC_MINUTE', '0'))
    timezone = os.getenv('SYNC_TIMEZONE', 'UTC')

    # Recurring daily sync at the configured time
    scheduler.add_job(
        sync_job,
        CronTrigger(
            hour=sync_hour,
            minute=sync_minute,
            timezone=pytz.timezone(timezone)
        ),
        id='daily_sync',
        name='Daily Dataset Sync'
    )

    # One-time sync one minute after startup
    initial_sync_time = datetime.now() + timedelta(minutes=1)
    scheduler.add_job(
        sync_job,
        'date',
        run_date=initial_sync_time,
        id='initial_sync',
        name='Initial Dataset Sync'
    )

    scheduler.start()
    logger.info(f"Scheduler initialized with daily sync and one-time initial sync at {initial_sync_time}")
    return scheduler


if __name__ == "__main__":
    try:
        scheduler = init_scheduler()

        # Keep the main thread alive; BackgroundScheduler runs its jobs in a background thread
        try:
            while True:
                time.sleep(1)
        except KeyboardInterrupt:
            scheduler.shutdown()
            logger.info("Scheduler shutdown complete")

    except Exception as e:
        logger.error(f"Main execution error: {str(e)}")
        raise
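# Note: a one-off manual sync (without starting the scheduler) can be run with
#   DatasetSynchronizer().sync_dataset()
# which is what sync_job() does on each scheduled run.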