#!/usr/bin/env python3
# Source: scratch_chat/scripts/init_db.py
# (uploaded by WebashalarForML in "Upload 178 files", commit 330b6e4)
"""Database initialization script for the chat agent application."""
import os
import sys
import argparse
from pathlib import Path
# Add the parent directory to the path so we can import modules
sys.path.append(str(Path(__file__).parent.parent))
from config import config
from migrations.migrate import DatabaseMigrator
def init_database(config_name='development', database_url=None, force=False):
    """Initialize the database with schema and initial data.

    Runs all migrations through ``DatabaseMigrator`` and then prints the
    migrator's status report.

    Args:
        config_name: Configuration name forwarded to ``DatabaseMigrator``.
        database_url: Optional database URL forwarded to ``DatabaseMigrator``
            (presumably overrides the configured URL -- confirm in
            ``migrations.migrate``).
        force: Accepted so the CLI can pass ``--force`` uniformly; this
            function never prompts, so the flag has no effect here.

    Exits the process with status 1 if the migrator cannot be created or if
    any migration/status step raises.
    """
    print(f"Initializing database for environment: {config_name}")
    try:
        # Constructed inside the try so that configuration/connection errors
        # raised by the migrator's constructor are reported like any other
        # initialization failure instead of escaping as a raw traceback.
        migrator = DatabaseMigrator(
            database_url=database_url,
            config_name=config_name,
        )
        print("Running database migrations...")
        migrator.migrate()
        print("Database initialization completed successfully!")
        print("\nMigration Status:")
        migrator.status()
    except Exception as e:  # CLI boundary: report and exit non-zero
        print(f"Error initializing database: {e}", file=sys.stderr)
        sys.exit(1)
def reset_database(config_name='development', database_url=None, force=False):
    """Reset the database by dropping and recreating all tables.

    The ``production`` configuration is always refused, even with *force*.
    Otherwise the user must type ``yes`` to proceed, unless *force* is true.

    Args:
        config_name: Configuration name passed to ``create_app``.
        database_url: Accepted for CLI symmetry but not used by this
            function; the database is whatever ``create_app(config_name)``
            configures. NOTE(review): confirm whether it should override it.
        force: Skip the interactive confirmation (set by ``--force``).

    Exits the process with status 1 when refused or when any step raises.
    """
    print(f"WARNING: This will destroy all data in the {config_name} database!")
    if config_name == 'production':
        print("ERROR: Cannot reset production database for safety reasons.",
              file=sys.stderr)
        sys.exit(1)

    if not force:
        try:
            confirm = input("Are you sure you want to continue? (yes/no): ")
        except EOFError:
            # stdin closed (e.g. piped / non-interactive run): treat as "no"
            # rather than crashing with a traceback.
            print()
            confirm = ''
        if confirm.strip().lower() != 'yes':
            print("Database reset cancelled.")
            return

    try:
        from app import create_app
        from chat_agent.models.base import db

        app = create_app(config_name)
        with app.app_context():
            print("Dropping all tables...")
            db.drop_all()
            print("Creating all tables...")
            db.create_all()
            print("Database reset completed!")
    except Exception as e:  # CLI boundary: report and exit non-zero
        print(f"Error resetting database: {e}", file=sys.stderr)
        sys.exit(1)


def seed_database(config_name='development'):
    """Seed the database with initial test data.

    Inserts one chat session, two messages (user + assistant) and a language
    context, all for ``python``. UUIDs are generated fresh on every call, so
    each run adds another sample session rather than replacing earlier ones.

    Args:
        config_name: Configuration name passed to ``create_app``.

    Exits the process with status 1 on any error, including missing project
    modules.
    """
    print(f"Seeding database for environment: {config_name}")
    try:
        import uuid

        from app import create_app
        from chat_agent.models.base import db
        from chat_agent.models.chat_session import ChatSession
        from chat_agent.models.language_context import LanguageContext
        from chat_agent.models.message import Message

        # (role, content) pairs for the sample conversation.
        sample_messages = [
            ('user', 'Hello! Can you help me understand Python functions?'),
            ('assistant',
             "Hello! I'd be happy to help you understand Python functions. "
             "A function is a reusable block of code that performs a "
             "specific task..."),
        ]

        app = create_app(config_name)
        with app.app_context():
            session_id = uuid.uuid4()
            user_id = uuid.uuid4()

            # Add the parent session first, then the rows referencing it.
            db.session.add(ChatSession(
                id=session_id,
                user_id=user_id,
                language='python',
                # Derived so the counter always matches the seeded messages.
                message_count=len(sample_messages),
                is_active=True,
            ))
            for role, content in sample_messages:
                db.session.add(Message(
                    session_id=session_id,
                    role=role,
                    content=content,
                    language='python',
                ))
            db.session.add(LanguageContext(
                session_id=session_id,
                language='python',
                prompt_template='You are a helpful Python programming assistant.',
                syntax_highlighting='python',
            ))
            db.session.commit()

            print("Database seeded with sample data!")
            print(f"Sample session ID: {session_id}")
            print(f"Sample user ID: {user_id}")
    except Exception as e:  # CLI boundary: report and exit non-zero
        print(f"Error seeding database: {e}", file=sys.stderr)
        sys.exit(1)


def main():
    """Main CLI interface for database initialization.

    Commands: ``init`` (run migrations), ``reset`` (drop + recreate tables),
    ``seed`` (insert sample data) and ``status`` (print migration status).
    Note that ``seed`` does not receive ``--database-url``, and ``reset``
    receives it but does not use it.
    """
    parser = argparse.ArgumentParser(description="Database initialization tool")
    parser.add_argument(
        "command",
        choices=["init", "reset", "seed", "status"],
        help="Database command to run",
    )
    parser.add_argument(
        "--config",
        default="development",
        choices=["development", "production", "testing"],
        help="Configuration environment",
    )
    parser.add_argument(
        "--database-url",
        help="Database URL (overrides config)",
    )
    parser.add_argument(
        "--force",
        action="store_true",
        help="Force operation without confirmation (skips the 'reset' prompt)",
    )
    args = parser.parse_args()

    if args.command == "init":
        init_database(args.config, args.database_url, args.force)
    elif args.command == "reset":
        # Forward --force so the confirmation prompt can actually be skipped.
        reset_database(args.config, args.database_url, force=args.force)
    elif args.command == "seed":
        seed_database(args.config)
    elif args.command == "status":
        try:
            migrator = DatabaseMigrator(
                database_url=args.database_url,
                config_name=args.config,
            )
            migrator.status()
        except Exception as e:  # same reporting as the other commands
            print(f"Error reading migration status: {e}", file=sys.stderr)
            sys.exit(1)
# Run the CLI only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()