Enhance API with data loading functionality and update README.
- Added `/load-data` endpoint to load transaction data from either a database or a CSV file.
- Updated `SalaryAnalyticsPipeline` and `DataLoader` to support loading from CSV.
- Implemented data validation and error handling for loading processes.
- Revised README to include new data loading instructions and workflow steps.
- Added checks to ensure data is loaded before running analysis endpoints.
This commit is contained in:
+75
-7
@@ -2,7 +2,7 @@
|
||||
FastAPI application for salary analytics.
|
||||
"""
|
||||
|
||||
from fastapi import FastAPI, HTTPException, BackgroundTasks
|
||||
from fastapi import FastAPI, HTTPException, BackgroundTasks, UploadFile, File
|
||||
from fastapi.responses import FileResponse
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from pydantic import BaseModel
|
||||
@@ -10,9 +10,14 @@ from typing import Optional, Dict
|
||||
import os
|
||||
import socket
|
||||
import logging
|
||||
import pandas as pd
|
||||
import tempfile
|
||||
|
||||
from .main import SalaryAnalyticsPipeline
|
||||
from .config import OUTPUT_PATHS
|
||||
from .data_loader import DataLoader
|
||||
from .salary_predictor import SalaryPredictor
|
||||
from .salary_earner_analyzer import SalaryEarnerAnalyzer
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(
|
||||
@@ -37,7 +42,13 @@ app.add_middleware(
|
||||
)
|
||||
|
||||
# Global pipeline instance
|
||||
pipeline = None
|
||||
pipeline = SalaryAnalyticsPipeline()
|
||||
|
||||
# Global variables to store loaded data and models
|
||||
data_loader = None
|
||||
df = None
|
||||
salary_predictor = None
|
||||
salary_earner_analyzer = None
|
||||
|
||||
class AnalysisResponse(BaseModel):
|
||||
"""Response model for analysis endpoints."""
|
||||
@@ -45,16 +56,19 @@ class AnalysisResponse(BaseModel):
|
||||
data: Optional[Dict] = None
|
||||
file_path: Optional[str] = None
|
||||
|
||||
def check_data_loaded():
    """Guard for analytics endpoints: reject the request when no data is loaded.

    Reads ``pipeline.df`` from the module-level pipeline instance and returns
    silently when it is set.

    Raises:
        HTTPException: 400 when ``pipeline.df`` is ``None``.
    """
    if pipeline.df is not None:
        return

    raise HTTPException(
        status_code=400,
        detail="No data loaded. Please load data first using the /load-data endpoint."
    )
|
||||
|
||||
@app.on_event("startup")
|
||||
async def startup_event():
|
||||
"""Initialize the pipeline on startup."""
|
||||
global pipeline
|
||||
try:
|
||||
logger.info("Initializing pipeline...")
|
||||
pipeline = SalaryAnalyticsPipeline()
|
||||
if not pipeline.load_data():
|
||||
logger.error("Failed to load data during startup")
|
||||
raise Exception("Failed to load data during startup")
|
||||
|
||||
# Print network information
|
||||
hostname = socket.gethostname()
|
||||
@@ -86,6 +100,7 @@ async def health_check():
|
||||
async def analyze_keyword():
|
||||
"""Run keyword-based salary transaction analysis."""
|
||||
try:
|
||||
check_data_loaded()
|
||||
logger.info("Starting keyword analysis...")
|
||||
data = pipeline.run_keyword_analysis()
|
||||
logger.info(f"Keyword analysis completed. Found {len(data)} matches")
|
||||
@@ -101,6 +116,7 @@ async def analyze_keyword():
|
||||
async def analyze_consistent_amount():
|
||||
"""Run consistent amount transaction analysis."""
|
||||
try:
|
||||
check_data_loaded()
|
||||
logger.info("Starting consistent amount analysis...")
|
||||
data = pipeline.run_consistent_amount_analysis()
|
||||
logger.info(f"Consistent amount analysis completed. Found {len(data)} matches")
|
||||
@@ -116,6 +132,7 @@ async def analyze_consistent_amount():
|
||||
async def analyze_transaction_type():
|
||||
"""Run transaction type analysis."""
|
||||
try:
|
||||
check_data_loaded()
|
||||
logger.info("Starting transaction type analysis...")
|
||||
data = pipeline.run_transaction_type_analysis()
|
||||
logger.info(f"Transaction type analysis completed. Found {len(data)} matches")
|
||||
@@ -131,6 +148,7 @@ async def analyze_transaction_type():
|
||||
async def generate_reports(background_tasks: BackgroundTasks):
|
||||
"""Generate salary earner reports."""
|
||||
try:
|
||||
check_data_loaded()
|
||||
logger.info("Starting report generation...")
|
||||
reports = pipeline.generate_salary_earner_reports()
|
||||
logger.info("Reports generated successfully")
|
||||
@@ -150,6 +168,7 @@ async def generate_reports(background_tasks: BackgroundTasks):
|
||||
async def train_models():
|
||||
"""Train salary prediction models."""
|
||||
try:
|
||||
check_data_loaded()
|
||||
logger.info("Starting model training...")
|
||||
pipeline.train_salary_prediction_models()
|
||||
logger.info("Models trained successfully")
|
||||
@@ -164,6 +183,7 @@ async def train_models():
|
||||
async def download_report(report_type: str):
|
||||
"""Download generated reports."""
|
||||
try:
|
||||
check_data_loaded()
|
||||
logger.info(f"Attempting to download report: {report_type}")
|
||||
file_paths = {
|
||||
"high_earners": OUTPUT_PATHS["high_earner_details"],
|
||||
@@ -197,6 +217,7 @@ async def download_report(report_type: str):
|
||||
async def run_full_pipeline():
|
||||
"""Run the complete salary analytics pipeline."""
|
||||
try:
|
||||
check_data_loaded()
|
||||
logger.info("Starting full pipeline...")
|
||||
success = pipeline.run_full_pipeline()
|
||||
if not success:
|
||||
@@ -209,4 +230,51 @@ async def run_full_pipeline():
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error in pipeline: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@app.post("/load-data")
async def load_data(source: str = "db", file: UploadFile = None):
    """
    Load data into the shared pipeline from either the database or a CSV file.

    Args:
        source (str): Source of data ('db' or 'csv')
        file (UploadFile): CSV file to load (required if source is 'csv')

    Returns:
        dict: Status of data loading, with the loaded column names and row count

    Raises:
        HTTPException: 400 when `source` is invalid or the CSV file is missing;
            500 when the pipeline reports a failed load or an unexpected
            error occurs.
    """
    try:
        if source not in ['db', 'csv']:
            raise HTTPException(status_code=400, detail="Source must be either 'db' or 'csv'")

        if source == 'csv' and not file:
            raise HTTPException(status_code=400, detail="File must be provided when loading from CSV")

        if source == 'csv':
            temp_file_path = None
            try:
                # Save uploaded file temporarily. The file is closed (end of
                # the `with`) before the pipeline opens it by path.
                with tempfile.NamedTemporaryFile(delete=False, suffix='.csv') as temp_file:
                    temp_file_path = temp_file.name
                    content = await file.read()
                    temp_file.write(content)

                success = pipeline.load_data(source='csv', file_path=temp_file_path)
            finally:
                # Clean up temporary file, including when reading/writing the
                # upload failed after the file had already been created.
                if temp_file_path is not None and os.path.exists(temp_file_path):
                    os.unlink(temp_file_path)
        else:
            success = pipeline.load_data(source='db')

        if not success:
            raise HTTPException(status_code=500, detail="Failed to load data")

        return {
            "status": "success",
            "message": f"Successfully loaded {len(pipeline.df)} rows of data",
            "columns": pipeline.df.columns.tolist(),
            "row_count": len(pipeline.df)
        }
    except HTTPException as http_exc:
        # Deliberate HTTP errors (e.g. the 400 validation responses) must keep
        # their own status code instead of being rewrapped as a generic 500.
        logger.error(f"Error loading data: {http_exc.detail}")
        raise
    except Exception as e:
        logger.error(f"Error loading data: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e)) from e
|
||||
Reference in New Issue
Block a user