Enhance API with data loading functionality and update README.

- Added `/load-data` endpoint to load transaction data from either a database or a CSV file.
- Updated `SalaryAnalyticsPipeline` and `DataLoader` to support loading from CSV.
- Implemented data validation and error handling for loading processes.
- Revised README to include new data loading instructions and workflow steps.
- Added checks to ensure data is loaded before running analysis endpoints.
This commit is contained in:
2025-05-01 22:57:55 +01:00
parent 7e7094f0fd
commit 8acfb436f3
12 changed files with 205 additions and 29 deletions
+6
View File
@@ -0,0 +1,6 @@
transaction.csv
output/csv/final_table.csv
output/csv/high_earner_details.csv
output/csv/likely_salary_earner.csv
output/plots/consistent_earners_predictions.png
output/plots/hypothesis_overlap.png
+27 -5
View File
@@ -46,7 +46,6 @@ salary_analytics/
└── api.py # FastAPI endpoints
```
## Configuration
The system can be configured through environment variables or the `config.py` file:
@@ -89,12 +88,26 @@ uvicorn salary_analytics.api:app --reload
- `GET /`: Welcome message
- `GET /health`: Health check
2. **Analysis Endpoints**
2. **Data Loading**
- `POST /load-data`: Load transaction data
- Parameters:
- `source`: Data source ('db' or 'csv')
- `file`: CSV file (required if source is 'csv')
- Example:
```bash
# Load from database
curl -X POST "http://localhost:8000/load-data?source=db"
# Load from CSV
curl -X POST "http://localhost:8000/load-data?source=csv" -F "file=@path/to/your/file.csv"
```
3. **Analysis Endpoints**
- `POST /analyze/keyword`: Run keyword analysis
- `POST /analyze/consistent-amount`: Run consistent amount analysis
- `POST /analyze/transaction-type`: Run transaction type analysis
3. **Report Generation**
4. **Report Generation**
- `POST /generate/reports`: Generate all reports
- `GET /download/{report_type}`: Download specific reports
- Available types:
@@ -105,12 +118,21 @@ uvicorn salary_analytics.api:app --reload
- `inconsistent_plot`: Inconsistent earners plot
- `hypothesis_plot`: Hypothesis overlap plot
4. **Model Training**
5. **Model Training**
- `POST /train/models`: Train prediction models
5. **Pipeline**
6. **Pipeline**
- `POST /run/pipeline`: Run complete pipeline
### Workflow
1. Start the API server
2. Load data using the `/load-data` endpoint
3. Run any of the analysis endpoints
4. Generate and download reports as needed
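Note: All analysis, report generation, report download, model training, and pipeline endpoints require data to be loaded first. If you call any of them without loading data, you'll receive a 400 error with a message telling you to load data first using the `/load-data` endpoint.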
## Docker Deployment
1. Build the Docker image:
Binary file not shown.
Binary file not shown.
+75 -7
View File
@@ -2,7 +2,7 @@
FastAPI application for salary analytics.
"""
from fastapi import FastAPI, HTTPException, BackgroundTasks
from fastapi import FastAPI, HTTPException, BackgroundTasks, UploadFile, File
from fastapi.responses import FileResponse
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
@@ -10,9 +10,14 @@ from typing import Optional, Dict
import os
import socket
import logging
import pandas as pd
import tempfile
from .main import SalaryAnalyticsPipeline
from .config import OUTPUT_PATHS
from .data_loader import DataLoader
from .salary_predictor import SalaryPredictor
from .salary_earner_analyzer import SalaryEarnerAnalyzer
# Configure logging
logging.basicConfig(
@@ -37,7 +42,13 @@ app.add_middleware(
)
# Global pipeline instance
pipeline = None
pipeline = SalaryAnalyticsPipeline()
# Global variables to store loaded data and models
data_loader = None
df = None
salary_predictor = None
salary_earner_analyzer = None
class AnalysisResponse(BaseModel):
"""Response model for analysis endpoints."""
@@ -45,16 +56,19 @@ class AnalysisResponse(BaseModel):
data: Optional[Dict] = None
file_path: Optional[str] = None
def check_data_loaded():
    """Guard helper: ensure the global pipeline holds a loaded DataFrame.

    Returns silently when ``pipeline.df`` is set.

    Raises:
        HTTPException: status 400 when no data has been loaded yet.
    """
    # NOTE(review): callers that invoke this inside a broad
    # ``except Exception`` handler may turn this 400 into a 500 - confirm.
    if pipeline.df is not None:
        return

    raise HTTPException(
        status_code=400,
        detail="No data loaded. Please load data first using the /load-data endpoint."
    )
@app.on_event("startup")
async def startup_event():
"""Initialize the pipeline on startup."""
global pipeline
try:
logger.info("Initializing pipeline...")
pipeline = SalaryAnalyticsPipeline()
if not pipeline.load_data():
logger.error("Failed to load data during startup")
raise Exception("Failed to load data during startup")
# Print network information
hostname = socket.gethostname()
@@ -86,6 +100,7 @@ async def health_check():
async def analyze_keyword():
"""Run keyword-based salary transaction analysis."""
try:
check_data_loaded()
logger.info("Starting keyword analysis...")
data = pipeline.run_keyword_analysis()
logger.info(f"Keyword analysis completed. Found {len(data)} matches")
@@ -101,6 +116,7 @@ async def analyze_keyword():
async def analyze_consistent_amount():
"""Run consistent amount transaction analysis."""
try:
check_data_loaded()
logger.info("Starting consistent amount analysis...")
data = pipeline.run_consistent_amount_analysis()
logger.info(f"Consistent amount analysis completed. Found {len(data)} matches")
@@ -116,6 +132,7 @@ async def analyze_consistent_amount():
async def analyze_transaction_type():
"""Run transaction type analysis."""
try:
check_data_loaded()
logger.info("Starting transaction type analysis...")
data = pipeline.run_transaction_type_analysis()
logger.info(f"Transaction type analysis completed. Found {len(data)} matches")
@@ -131,6 +148,7 @@ async def analyze_transaction_type():
async def generate_reports(background_tasks: BackgroundTasks):
"""Generate salary earner reports."""
try:
check_data_loaded()
logger.info("Starting report generation...")
reports = pipeline.generate_salary_earner_reports()
logger.info("Reports generated successfully")
@@ -150,6 +168,7 @@ async def generate_reports(background_tasks: BackgroundTasks):
async def train_models():
"""Train salary prediction models."""
try:
check_data_loaded()
logger.info("Starting model training...")
pipeline.train_salary_prediction_models()
logger.info("Models trained successfully")
@@ -164,6 +183,7 @@ async def train_models():
async def download_report(report_type: str):
"""Download generated reports."""
try:
check_data_loaded()
logger.info(f"Attempting to download report: {report_type}")
file_paths = {
"high_earners": OUTPUT_PATHS["high_earner_details"],
@@ -197,6 +217,7 @@ async def download_report(report_type: str):
async def run_full_pipeline():
"""Run the complete salary analytics pipeline."""
try:
check_data_loaded()
logger.info("Starting full pipeline...")
success = pipeline.run_full_pipeline()
if not success:
@@ -209,4 +230,51 @@ async def run_full_pipeline():
)
except Exception as e:
logger.error(f"Error in pipeline: {str(e)}")
raise HTTPException(status_code=500, detail=str(e))
@app.post("/load-data")
async def load_data(source: str = "db", file: UploadFile = None):
    """
    Load data from either database or CSV file.

    Args:
        source (str): Source of data ('db' or 'csv')
        file (UploadFile): CSV file to load (required if source is 'csv')

    Returns:
        dict: Status of data loading, with the loaded column names and row count

    Raises:
        HTTPException: 400 when ``source`` is invalid or the CSV file is
            missing; 500 when the pipeline fails to load the data.
    """
    # Validate the request outside the try block so client errors are
    # reported as 400 instead of being swallowed by the catch-all handler
    # below and re-raised as 500.
    if source not in ['db', 'csv']:
        raise HTTPException(status_code=400, detail="Source must be either 'db' or 'csv'")

    if source == 'csv' and not file:
        raise HTTPException(status_code=400, detail="File must be provided when loading from CSV")

    try:
        if source == 'csv':
            # Read the upload first so a failed read never leaves a temp file behind.
            content = await file.read()
            temp_file_path = None
            try:
                # delete=False: the pipeline reopens the file by path, so it
                # must outlive this handle; it is removed in the finally below.
                with tempfile.NamedTemporaryFile(delete=False, suffix='.csv') as temp_file:
                    temp_file_path = temp_file.name
                    temp_file.write(content)
                success = pipeline.load_data(source='csv', file_path=temp_file_path)
            finally:
                # Clean up temporary file, even when writing or loading failed
                if temp_file_path is not None:
                    os.unlink(temp_file_path)
        else:
            success = pipeline.load_data(source='db')

        if not success:
            raise HTTPException(status_code=500, detail="Failed to load data")

        return {
            "status": "success",
            "message": f"Successfully loaded {len(pipeline.df)} rows of data",
            "columns": pipeline.df.columns.tolist(),
            "row_count": len(pipeline.df)
        }
    except HTTPException:
        # Already carries the intended status code and detail; pass it through.
        raise
    except Exception as e:
        logger.error(f"Error loading data: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))
+10 -4
View File
@@ -40,12 +40,18 @@ class ConsistentAmountAnalyzer:
if cv_threshold is None:
cv_threshold = MODEL_CONFIG['cv_threshold']
self.df = self.df.groupby('accountid').apply(
# Create a copy of the original DataFrame
self.const_df = self.df.copy()
# Calculate consistent amount flags
consistent_flags = self.const_df.groupby('accountid').apply(
lambda group: self.flag_consistent_amounts(group, cv_threshold)
).reset_index(level=0, drop=True)
self.const_df = self.df.copy()
return self.df
# Add the flags to the original DataFrame
self.const_df['is_consistent_amount'] = consistent_flags
return self.const_df
def get_consistent_amount_data(self):
"""Get transactions identified as having consistent amounts."""
+57 -2
View File
@@ -6,6 +6,7 @@ from sqlalchemy import create_engine, text
import pandas as pd
from datetime import datetime
import logging
import os
from .config import DB_CONFIG, TABLE_NAME
logger = logging.getLogger(__name__)
@@ -44,8 +45,49 @@ class DataLoader:
logger.error(f"Error connecting to database: {str(e)}")
return False
def load_data(self):
"""Load and preprocess transaction data in chunks."""
def load_from_csv(self, file_path):
    """Read a transaction CSV in chunks and store the combined frame on ``self.df``.

    Each chunk has its ``trx_start_date`` / ``trx_end_date`` columns parsed
    as datetimes; when a ``d1`` column is present the generic ``d1``-``d4``
    columns are renamed to their descriptive names.

    Args:
        file_path: Path of the CSV file to read.

    Returns:
        The combined DataFrame, or None when the file does not exist or
        any error occurs while reading/processing it.
    """
    date_columns = ('trx_start_date', 'trx_end_date')
    generic_to_named = {
        'd1': 'trx_type',
        'd2': 'trx_subtype',
        'd3': 'initiated_by',
        'd4': 'customer_id'
    }

    def _prepare(part):
        # Parse the date columns first, then apply the optional rename.
        for column in date_columns:
            part[column] = pd.to_datetime(part[column])
        if 'd1' in part.columns:
            part = part.rename(columns=generic_to_named)
        return part

    try:
        logger.info(f"Loading data from CSV file: {file_path}")

        if not os.path.exists(file_path):
            logger.error(f"CSV file not found: {file_path}")
            return None

        # Read chunk by chunk (chunk size comes from the loader instance)
        reader = pd.read_csv(file_path, chunksize=self.chunk_size)
        prepared_parts = [_prepare(part) for part in reader]

        # Stitch the processed chunks into one frame
        self.df = pd.concat(prepared_parts, ignore_index=True)
        logger.info(f"Successfully loaded {len(self.df)} rows from CSV")

        # Log a summary of the loaded data (columns, dtypes, null counts)
        logger.info("Performing data validation...")
        logger.info(f"Columns in dataset: {self.df.columns.tolist()}")
        logger.info(f"Data types:\n{self.df.dtypes}")
        logger.info(f"Missing values:\n{self.df.isnull().sum()}")

        return self.df
    except Exception as e:
        logger.error(f"Error loading data from CSV: {str(e)}")
        return None
def load_from_db(self):
"""Load and preprocess transaction data from database in chunks."""
if not self.engine:
logger.info("No database connection. Attempting to connect...")
if not self.connect():
@@ -106,6 +148,19 @@ class DataLoader:
logger.error(f"Error loading data: {str(e)}")
return None
def load_data(self, source='db', file_path=None):
    """Dispatch to the database or CSV loader according to ``source``.

    Args:
        source: 'db' to load via ``load_from_db`` or 'csv' to load via
            ``load_from_csv``.
        file_path: Path of the CSV file; must be truthy when source is 'csv'.

    Returns:
        Whatever the selected loader returns, or None when ``source`` is
        not recognised or the CSV path is missing.
    """
    if source == 'db':
        return self.load_from_db()

    if source == 'csv':
        if file_path:
            return self.load_from_csv(file_path)
        logger.error("File path must be provided when loading from CSV")
        return None

    logger.error(f"Invalid source: {source}. Must be 'db' or 'csv'")
    return None
def get_data(self):
"""Get the loaded DataFrame."""
if self.df is None:
+27 -7
View File
@@ -23,11 +23,11 @@ class SalaryAnalyticsPipeline:
self.salary_earner_analyzer = None
self.salary_predictor = None
def load_data(self):
def load_data(self, source='db', file_path=None):
"""Load and preprocess the transaction data."""
logger.info("Starting data loading process")
self.data_loader = DataLoader()
self.df = self.data_loader.load_data()
self.df = self.data_loader.load_data(source=source, file_path=file_path)
if self.df is not None:
logger.info(f"Successfully loaded data with {len(self.df)} rows")
else:
@@ -43,7 +43,11 @@ class SalaryAnalyticsPipeline:
logger.info("Starting keyword analysis")
self.keyword_analyzer = KeywordAnalyzer(self.df)
self.keyword_analyzer.identify_salary_transactions()
return self.keyword_analyzer.get_salary_related_data()
keyword_data = self.keyword_analyzer.get_salary_related_data()
# Update main DataFrame with keyword analysis results
self.df['is_salary_related'] = self.df.index.isin(keyword_data.index)
return keyword_data
def run_consistent_amount_analysis(self):
"""Run consistent amount transaction analysis."""
@@ -54,7 +58,11 @@ class SalaryAnalyticsPipeline:
logger.info("Starting consistent amount analysis")
self.consistent_amount_analyzer = ConsistentAmountAnalyzer(self.df)
self.consistent_amount_analyzer.identify_consistent_amount_accounts()
return self.consistent_amount_analyzer.get_consistent_amount_data()
consistent_data = self.consistent_amount_analyzer.get_consistent_amount_data()
# Update main DataFrame with consistent amount analysis results
self.df['is_consistent_amount'] = self.df.index.isin(consistent_data.index)
return consistent_data
def run_transaction_type_analysis(self):
"""Run transaction type analysis."""
@@ -65,7 +73,11 @@ class SalaryAnalyticsPipeline:
logger.info("Starting transaction type analysis")
self.transaction_type_analyzer = TransactionTypeAnalyzer(self.df)
self.transaction_type_analyzer.flag_salary_type_transactions()
return self.transaction_type_analyzer.get_salary_type_data()
type_data = self.transaction_type_analyzer.get_salary_type_data()
# Update main DataFrame with transaction type analysis results
self.df['is_salary_type'] = self.df.index.isin(type_data.index)
return type_data
def generate_salary_earner_reports(self):
"""Generate salary earner reports."""
@@ -73,6 +85,14 @@ class SalaryAnalyticsPipeline:
logger.error("Data not loaded. Call load_data() first.")
raise ValueError("Data not loaded. Call load_data() first.")
# Ensure all analysis flags are present
required_columns = ['is_salary_related', 'is_consistent_amount', 'is_salary_type']
missing_columns = [col for col in required_columns if col not in self.df.columns]
if missing_columns:
logger.error(f"Missing required columns: {missing_columns}")
raise ValueError(f"Missing required columns: {missing_columns}. Run all analyses first.")
logger.info("Starting salary earner report generation")
self.salary_earner_analyzer = SalaryEarnerAnalyzer(self.df)
return self.salary_earner_analyzer.generate_reports()
@@ -96,10 +116,10 @@ class SalaryAnalyticsPipeline:
self.salary_predictor.train_and_evaluate(consistent_accounts, inconsistent_accounts)
def run_full_pipeline(self):
def run_full_pipeline(self, source='db', file_path=None):
"""Run the complete salary analytics pipeline."""
logger.info("Starting full pipeline execution")
if not self.load_data():
if not self.load_data(source=source, file_path=file_path):
logger.error("Failed to load data. Exiting pipeline.")
return False
+3 -4
View File
@@ -83,11 +83,10 @@ class SalaryEarnerAnalyzer:
def analyze_salary_earners(self, final_df):
    """Analyze salary earners and identify high earners.

    Args:
        final_df: DataFrame providing the 'accountid', 'least_inflow_6m'
            and 'estimated_next_amount' columns.

    Returns:
        tuple: ``(high_earner_details, count_high)`` where
            ``high_earner_details`` holds the 'accountid' and
            'least_inflow_6m' columns of the high earners with a fresh
            0-based index, and ``count_high`` is the number of such rows.
    """
    threshold = MODEL_CONFIG['high_earner_threshold']

    # Filter once and copy, so the result is independent of final_df and
    # later writes cannot trigger pandas' SettingWithCopyWarning.
    high_earners = final_df[final_df['estimated_next_amount'] >= threshold].copy()

    high_earner_details = high_earners[['accountid', 'least_inflow_6m']].reset_index(drop=True)
    count_high = len(high_earners)
    return high_earner_details, count_high
def generate_reports(self):