Files
salakojoshua1234_gmail.com 9c429caa56 Implement streaming pipeline endpoint for batch processing
- Added `/run/streaming-pipeline` endpoint to process data in batches from either a database or CSV file.
- Introduced `BatchResponse` model for structured responses.
- Updated README with new endpoint details, including parameters and example usage.
- Enhanced error handling and logging during batch processing.
- Ensured data preprocessing and NaN handling in analysis functions.
2025-05-02 14:25:31 +01:00

170 lines
6.6 KiB
Python

"""
Data loading and preprocessing module.
"""
from sqlalchemy import create_engine, text
import pandas as pd
from datetime import datetime
import logging
import os
from .config import DB_CONFIG, TABLE_NAME
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)
class DataLoader:
    """Load transaction data from the configured database table or a CSV file.

    Data is read in chunks of ``chunk_size`` rows; every chunk has its
    ``trx_start_date`` / ``trx_end_date`` columns parsed as datetimes, the
    generic ``d1``..``d4`` columns renamed, and rows containing missing
    values dropped. The combined result is kept in ``self.df``.

    Attributes:
        engine: SQLAlchemy engine, or ``None`` until ``connect()`` succeeds.
        df: The most recently loaded DataFrame, or ``None``.
        chunk_size: Number of rows read per chunk.
    """

    # Generic source column names and the descriptive names they are renamed to.
    _COLUMN_RENAMES = {
        'd1': 'trx_type',
        'd2': 'trx_subtype',
        'd3': 'initiated_by',
        'd4': 'customer_id',
    }
    # Columns converted with pd.to_datetime in every chunk.
    _DATE_COLUMNS = ('trx_start_date', 'trx_end_date')

    def __init__(self):
        self.engine = None
        self.df = None
        self.chunk_size = 10000  # Load 10,000 rows at a time

    def connect(self):
        """Create the engine and verify that ``TABLE_NAME`` is reachable.

        Returns:
            bool: True when the database answers and the table exists,
            False otherwise. On any failure ``self.engine`` is reset to
            ``None`` so a later ``load_from_db()`` re-runs this validation
            instead of treating a failed attempt as an open connection.
        """
        try:
            logger.info("Attempting to connect to database...")
            # NOTE(review): credentials are interpolated into the URL verbatim.
            # If they can contain URL-reserved characters (e.g. '@' or '/'),
            # they need escaping -- confirm how DB_CONFIG values are stored.
            database_url = (
                f"postgresql://{DB_CONFIG['user']}:{DB_CONFIG['password']}"
                f"@{DB_CONFIG['host']}:{DB_CONFIG['port']}/{DB_CONFIG['name']}"
            )
            self.engine = create_engine(database_url)

            # The table name is a *value* in this query, so bind it as a
            # parameter instead of formatting it into the SQL text.
            check_table = text(
                "SELECT EXISTS (SELECT FROM information_schema.tables "
                "WHERE table_name = :table_name)"
            )
            with self.engine.connect() as conn:
                table_exists = conn.execute(
                    check_table, {"table_name": TABLE_NAME}
                ).scalar()
                if table_exists:
                    # Identifiers cannot be bound as parameters; TABLE_NAME
                    # comes from the project's own configuration.
                    count_query = text(f"SELECT COUNT(*) FROM {TABLE_NAME}")
                    row_count = conn.execute(count_query).scalar()
                    logger.info(f"Table {TABLE_NAME} exists with {row_count} rows")
                    # Final round-trip query; its result is not needed.
                    conn.execute(text("SELECT version();"))

            if not table_exists:
                logger.error(f"Table {TABLE_NAME} does not exist in the database")
                self._discard_engine()
                return False

            logger.info("Connected successfully to database!")
            return True
        except Exception as e:
            logger.error(f"Error connecting to database: {str(e)}")
            self._discard_engine()
            return False

    def _discard_engine(self):
        """Drop the current engine (if any) so the next load reconnects."""
        engine, self.engine = self.engine, None
        if engine is None:
            return
        try:
            engine.dispose()
        except Exception as e:
            # Best effort only: the engine is already detached from self.
            logger.warning(f"Error disposing database engine: {str(e)}")

    def _preprocess_chunk(self, chunk, rename=True):
        """Parse the date columns, optionally rename d1..d4, drop NaN rows."""
        for column in self._DATE_COLUMNS:
            chunk[column] = pd.to_datetime(chunk[column])
        if rename:
            chunk = chunk.rename(columns=self._COLUMN_RENAMES)
        return chunk.dropna()

    def _log_validation_summary(self):
        """Log columns, dtypes and missing-value counts of ``self.df``."""
        logger.info("Performing data validation...")
        logger.info(f"Columns in dataset: {self.df.columns.tolist()}")
        logger.info(f"Data types:\n{self.df.dtypes}")
        logger.info(f"Missing values:\n{self.df.isnull().sum()}")

    def load_from_csv(self, file_path):
        """Load and preprocess transaction data from a CSV file.

        Args:
            file_path: Path of the CSV file to read.

        Returns:
            The combined DataFrame (also stored in ``self.df``), or ``None``
            when the file is missing, yields no chunks, or reading fails.
        """
        try:
            logger.info(f"Loading data from CSV file: {file_path}")
            if not os.path.exists(file_path):
                logger.error(f"CSV file not found: {file_path}")
                return None

            chunks = []
            for chunk in pd.read_csv(file_path, chunksize=self.chunk_size):
                # CSV exports are only renamed when they still carry the
                # generic column names (detected via 'd1').
                chunks.append(
                    self._preprocess_chunk(chunk, rename='d1' in chunk.columns)
                )

            if not chunks:
                # pd.concat([]) would raise a far less helpful error.
                logger.error(f"No rows found in CSV file: {file_path}")
                return None

            self.df = pd.concat(chunks, ignore_index=True)
            logger.info(f"Successfully loaded {len(self.df)} rows from CSV")
            self._log_validation_summary()
            return self.df
        except Exception as e:
            logger.error(f"Error loading data from CSV: {str(e)}")
            return None

    def load_from_db(self):
        """Load and preprocess transaction data from the database in chunks.

        The table is read with a single ``SELECT`` whose result is consumed
        ``chunk_size`` rows at a time. Paging with ``LIMIT``/``OFFSET`` and no
        ``ORDER BY`` gives no guarantee that consecutive pages are disjoint,
        and re-scans the skipped rows on every page.

        Returns:
            The combined DataFrame (also stored in ``self.df``), or ``None``
            when no connection can be established, the table yields no rows,
            or loading fails.
        """
        if not self.engine:
            logger.info("No database connection. Attempting to connect...")
            if not self.connect():
                logger.error("Failed to establish database connection")
                return None

        try:
            logger.info(f"Loading data from table: {TABLE_NAME}")
            chunks = []
            rows_read = 0
            with self.engine.connect() as conn:
                # The count is informational only (progress logging).
                count_query = text(f"SELECT COUNT(*) FROM {TABLE_NAME}")
                total_rows = conn.execute(count_query).scalar()
                logger.info(f"Total rows to process: {total_rows}")

                select_all = text(f"SELECT * FROM {TABLE_NAME}")
                reader = pd.read_sql(select_all, conn, chunksize=self.chunk_size)
                for chunk in reader:
                    if chunk.empty:
                        continue
                    logger.info(f"Processing chunk starting at row {rows_read}")
                    rows_read += len(chunk)
                    chunks.append(self._preprocess_chunk(chunk))

            if not chunks:
                # pd.concat([]) would raise a far less helpful error.
                logger.error(f"No rows found in table {TABLE_NAME}")
                return None

            self.df = pd.concat(chunks, ignore_index=True)
            logger.info(f"Successfully loaded {len(self.df)} rows of data")
            self._log_validation_summary()
            return self.df
        except Exception as e:
            logger.error(f"Error loading data: {str(e)}")
            return None

    def load_data(self, source='db', file_path=None):
        """Load data from either the database or a CSV file.

        Args:
            source: ``'db'`` or ``'csv'``.
            file_path: CSV path; required when ``source`` is ``'csv'``.

        Returns:
            The loaded DataFrame, or ``None`` on invalid arguments or failure.
        """
        if source == 'db':
            return self.load_from_db()
        if source == 'csv':
            if not file_path:
                logger.error("File path must be provided when loading from CSV")
                return None
            return self.load_from_csv(file_path)
        logger.error(f"Invalid source: {source}. Must be 'db' or 'csv'")
        return None

    def get_data(self):
        """Return the loaded DataFrame, or ``None`` (with a warning) if none."""
        if self.df is None:
            logger.warning("No data loaded. Call load_data() first.")
        return self.df