Enhance API with data loading functionality and update README.
- Added `/load-data` endpoint to load transaction data from either a database or a CSV file.
- Updated `SalaryAnalyticsPipeline` and `DataLoader` to support loading from CSV.
- Implemented data validation and error handling for loading processes.
- Revised README to include new data loading instructions and workflow steps.
- Added checks to ensure data is loaded before running analysis endpoints.
This commit is contained in:
@@ -6,6 +6,7 @@ from sqlalchemy import create_engine, text
|
||||
import pandas as pd
|
||||
from datetime import datetime
|
||||
import logging
|
||||
import os
|
||||
from .config import DB_CONFIG, TABLE_NAME
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -44,8 +45,49 @@ class DataLoader:
|
||||
logger.error(f"Error connecting to database: {str(e)}")
|
||||
return False
|
||||
|
||||
def load_data(self):
|
||||
"""Load and preprocess transaction data in chunks."""
|
||||
def load_from_csv(self, file_path):
    """Read a CSV file chunk by chunk and store the result on ``self.df``.

    Each chunk has its ``trx_start_date`` and ``trx_end_date`` columns
    converted with ``pd.to_datetime``. When a chunk carries a ``d1``
    column, the ``d1``..``d4`` columns are renamed to their descriptive
    names. All chunks are then concatenated into one DataFrame.

    Args:
        file_path: Location of the CSV file to read.

    Returns:
        The combined DataFrame (also assigned to ``self.df``), or ``None``
        when the file does not exist or any exception is raised while
        loading; the failure is logged in both cases.
    """
    try:
        logger.info(f"Loading data from CSV file: {file_path}")
        if not os.path.exists(file_path):
            logger.error(f"CSV file not found: {file_path}")
            return None

        # Mapping from the short column names to the descriptive ones.
        short_to_descriptive = {
            'd1': 'trx_type',
            'd2': 'trx_subtype',
            'd3': 'initiated_by',
            'd4': 'customer_id',
        }

        # Read the file piecewise, preprocessing every piece as it arrives.
        frames = []
        reader = pd.read_csv(file_path, chunksize=self.chunk_size)
        for part in reader:
            for date_column in ('trx_start_date', 'trx_end_date'):
                part[date_column] = pd.to_datetime(part[date_column])

            # The presence of 'd1' is the signal that renaming is needed.
            if 'd1' in part.columns:
                part = part.rename(columns=short_to_descriptive)

            frames.append(part)

        # Stitch the pieces back together with a fresh 0..n-1 index.
        combined = pd.concat(frames, ignore_index=True)
        self.df = combined
        logger.info(f"Successfully loaded {len(combined)} rows from CSV")

        # Log a basic profile of the loaded data (nothing is rejected here).
        logger.info("Performing data validation...")
        logger.info(f"Columns in dataset: {combined.columns.tolist()}")
        logger.info(f"Data types:\n{combined.dtypes}")
        logger.info(f"Missing values:\n{combined.isnull().sum()}")

        return combined
    except Exception as e:
        logger.error(f"Error loading data from CSV: {str(e)}")
        return None
|
||||
|
||||
def load_from_db(self):
|
||||
"""Load and preprocess transaction data from database in chunks."""
|
||||
if not self.engine:
|
||||
logger.info("No database connection. Attempting to connect...")
|
||||
if not self.connect():
|
||||
@@ -106,6 +148,19 @@ class DataLoader:
|
||||
logger.error(f"Error loading data: {str(e)}")
|
||||
return None
|
||||
|
||||
def load_data(self, source='db', file_path=None):
    """Load data from the requested source by delegating to a loader.

    Args:
        source: ``'db'`` to call ``load_from_db`` or ``'csv'`` to call
            ``load_from_csv``. Defaults to ``'db'``.
        file_path: Path handed to ``load_from_csv``; must be truthy when
            ``source`` is ``'csv'``.

    Returns:
        Whatever the selected loader returns, or ``None`` (after logging
        an error) when ``source`` is not recognised or the CSV path is
        missing.
    """
    if source == 'db':
        return self.load_from_db()

    if source == 'csv':
        if file_path:
            return self.load_from_csv(file_path)
        logger.error("File path must be provided when loading from CSV")
        return None

    # Anything other than the two supported sources is rejected.
    logger.error(f"Invalid source: {source}. Must be 'db' or 'csv'")
    return None
|
||||
|
||||
def get_data(self):
|
||||
"""Get the loaded DataFrame."""
|
||||
if self.df is None:
|
||||
|
||||
Reference in New Issue
Block a user