Enhance API with data loading functionality and update README.

- Added `/load-data` endpoint to load transaction data from either a database or a CSV file.
- Updated `SalaryAnalyticsPipeline` and `DataLoader` to support loading from CSV.
- Implemented data validation and error handling for loading processes.
- Revised README to include new data loading instructions and workflow steps.
- Added checks to ensure data is loaded before running analysis endpoints.
This commit is contained in:
2025-05-01 22:57:55 +01:00
parent 7e7094f0fd
commit 8acfb436f3
12 changed files with 205 additions and 29 deletions
+75 -7
View File
@@ -2,7 +2,7 @@
FastAPI application for salary analytics.
"""
from fastapi import FastAPI, HTTPException, BackgroundTasks
from fastapi import FastAPI, HTTPException, BackgroundTasks, UploadFile, File
from fastapi.responses import FileResponse
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
@@ -10,9 +10,14 @@ from typing import Optional, Dict
import os
import socket
import logging
import pandas as pd
import tempfile
from .main import SalaryAnalyticsPipeline
from .config import OUTPUT_PATHS
from .data_loader import DataLoader
from .salary_predictor import SalaryPredictor
from .salary_earner_analyzer import SalaryEarnerAnalyzer
# Configure logging
logging.basicConfig(
@@ -37,7 +42,13 @@ app.add_middleware(
)
# Global pipeline instance
pipeline = None
pipeline = SalaryAnalyticsPipeline()
# Global variables to store loaded data and models
data_loader = None
df = None
salary_predictor = None
salary_earner_analyzer = None
class AnalysisResponse(BaseModel):
"""Response model for analysis endpoints."""
@@ -45,16 +56,19 @@ class AnalysisResponse(BaseModel):
data: Optional[Dict] = None
file_path: Optional[str] = None
def check_data_loaded():
"""Check if data is loaded before running analytics."""
if pipeline.df is None:
raise HTTPException(
status_code=400,
detail="No data loaded. Please load data first using the /load-data endpoint."
)
@app.on_event("startup")
async def startup_event():
"""Initialize the pipeline on startup."""
global pipeline
try:
logger.info("Initializing pipeline...")
pipeline = SalaryAnalyticsPipeline()
if not pipeline.load_data():
logger.error("Failed to load data during startup")
raise Exception("Failed to load data during startup")
# Print network information
hostname = socket.gethostname()
@@ -86,6 +100,7 @@ async def health_check():
async def analyze_keyword():
"""Run keyword-based salary transaction analysis."""
try:
check_data_loaded()
logger.info("Starting keyword analysis...")
data = pipeline.run_keyword_analysis()
logger.info(f"Keyword analysis completed. Found {len(data)} matches")
@@ -101,6 +116,7 @@ async def analyze_keyword():
async def analyze_consistent_amount():
"""Run consistent amount transaction analysis."""
try:
check_data_loaded()
logger.info("Starting consistent amount analysis...")
data = pipeline.run_consistent_amount_analysis()
logger.info(f"Consistent amount analysis completed. Found {len(data)} matches")
@@ -116,6 +132,7 @@ async def analyze_consistent_amount():
async def analyze_transaction_type():
"""Run transaction type analysis."""
try:
check_data_loaded()
logger.info("Starting transaction type analysis...")
data = pipeline.run_transaction_type_analysis()
logger.info(f"Transaction type analysis completed. Found {len(data)} matches")
@@ -131,6 +148,7 @@ async def analyze_transaction_type():
async def generate_reports(background_tasks: BackgroundTasks):
"""Generate salary earner reports."""
try:
check_data_loaded()
logger.info("Starting report generation...")
reports = pipeline.generate_salary_earner_reports()
logger.info("Reports generated successfully")
@@ -150,6 +168,7 @@ async def generate_reports(background_tasks: BackgroundTasks):
async def train_models():
"""Train salary prediction models."""
try:
check_data_loaded()
logger.info("Starting model training...")
pipeline.train_salary_prediction_models()
logger.info("Models trained successfully")
@@ -164,6 +183,7 @@ async def train_models():
async def download_report(report_type: str):
"""Download generated reports."""
try:
check_data_loaded()
logger.info(f"Attempting to download report: {report_type}")
file_paths = {
"high_earners": OUTPUT_PATHS["high_earner_details"],
@@ -197,6 +217,7 @@ async def download_report(report_type: str):
async def run_full_pipeline():
"""Run the complete salary analytics pipeline."""
try:
check_data_loaded()
logger.info("Starting full pipeline...")
success = pipeline.run_full_pipeline()
if not success:
@@ -209,4 +230,51 @@ async def run_full_pipeline():
)
except Exception as e:
logger.error(f"Error in pipeline: {str(e)}")
raise HTTPException(status_code=500, detail=str(e))
@app.post("/load-data")
async def load_data(source: str = "db", file: UploadFile = None):
"""
Load data from either database or CSV file.
Args:
source (str): Source of data ('db' or 'csv')
file (UploadFile): CSV file to load (required if source is 'csv')
Returns:
dict: Status of data loading
"""
try:
if source not in ['db', 'csv']:
raise HTTPException(status_code=400, detail="Source must be either 'db' or 'csv'")
if source == 'csv' and not file:
raise HTTPException(status_code=400, detail="File must be provided when loading from CSV")
if source == 'csv':
# Save uploaded file temporarily
with tempfile.NamedTemporaryFile(delete=False, suffix='.csv') as temp_file:
content = await file.read()
temp_file.write(content)
temp_file_path = temp_file.name
try:
success = pipeline.load_data(source='csv', file_path=temp_file_path)
finally:
# Clean up temporary file
os.unlink(temp_file_path)
else:
success = pipeline.load_data(source='db')
if not success:
raise HTTPException(status_code=500, detail="Failed to load data")
return {
"status": "success",
"message": f"Successfully loaded {len(pipeline.df)} rows of data",
"columns": pipeline.df.columns.tolist(),
"row_count": len(pipeline.df)
}
except Exception as e:
logger.error(f"Error loading data: {str(e)}")
raise HTTPException(status_code=500, detail=str(e))