292 lines
13 KiB
Python
292 lines
13 KiB
Python
from fastapi import APIRouter, HTTPException
|
|
from app.salary_analytics.services.main import SalaryAnalyticsPipeline
|
|
from app.salary_analytics.helpers.response_helpers import AnalysisResponse, BatchResponse
|
|
from app.salary_analytics.helpers.data_checks import check_data_loaded
|
|
from app.salary_analytics.services.data_loader import DataLoader
|
|
from app.salary_analytics.core.state import state
|
|
from app.models.db_operations import DatabaseOperations
|
|
from app.config import OUTPUT_PATHS, TABLE_NAME
|
|
from app.utils.logger import logger
|
|
from typing import Optional, List, Union
|
|
from sqlalchemy import text
|
|
from datetime import datetime
|
|
import pandas as pd, os, tempfile, time
|
|
from typing import Optional, Union
|
|
from fastapi import UploadFile, File
|
|
|
|
# Router on which the salary-analytics endpoints defined in this module are registered.
router = APIRouter()
|
|
|
|
|
|
@router.post("/run/pipeline", response_model=AnalysisResponse)
async def run_full_pipeline():
    """Run the complete salary analytics pipeline.

    Calls ``check_data_loaded()`` first, then runs
    ``state.pipeline.run_full_pipeline()`` and reports the outcome.

    Returns:
        AnalysisResponse: Confirmation message when the pipeline succeeds.

    Raises:
        HTTPException: Any HTTPException raised inside the handler (including
            one raised by ``check_data_loaded``, if it raises one) is passed
            through with its original status code. Status 500 is returned when
            the pipeline reports failure or an unexpected error occurs.
    """
    start_time = time.time()
    try:
        check_data_loaded()
        logger.info("Starting full pipeline...")

        success = state.pipeline.run_full_pipeline()
        if not success:
            logger.error("Pipeline failed")
            raise HTTPException(status_code=500, detail="Pipeline failed")

        logger.info("Pipeline completed successfully")
        response = AnalysisResponse(message="Pipeline completed successfully")
        logger.info(f"Full pipeline endpoint completed in {time.time() - start_time:.2f} seconds")
        return response
    except HTTPException:
        # HTTPException is an Exception subclass: without this clause the
        # generic handler below would re-wrap every deliberate HTTP error as a
        # 500 with a stringified detail, hiding the intended status code.
        logger.info(f"Full pipeline endpoint failed after {time.time() - start_time:.2f} seconds")
        raise
    except Exception as e:
        logger.error(f"Error in pipeline: {str(e)}")
        logger.info(f"Full pipeline endpoint failed after {time.time() - start_time:.2f} seconds")
        raise HTTPException(status_code=500, detail=str(e)) from e
|
|
|
|
|
|
def _preprocess_chunk(chunk):
    """Normalise one raw chunk before it is handed to the pipeline.

    Parses ``trx_start_date`` / ``trx_end_date`` as datetimes, renames the
    generic ``d1``..``d4`` columns to their descriptive names and drops every
    row that contains a missing value.

    NOTE(review): the original docstring says this mirrors DataLoader's
    preprocessing; that cannot be confirmed from this file.

    Args:
        chunk: DataFrame holding one batch of raw rows. The two date columns
            are converted in place on the object passed in.

    Returns:
        The renamed DataFrame with incomplete rows removed.
    """
    chunk['trx_start_date'] = pd.to_datetime(chunk['trx_start_date'])
    chunk['trx_end_date'] = pd.to_datetime(chunk['trx_end_date'])

    chunk = chunk.rename(columns={
        'd1': 'trx_type',
        'd2': 'trx_subtype',
        'd3': 'initiated_by',
        'd4': 'customer_id'
    })

    return chunk.dropna()


def _process_batch(chunk, batch_number, total_batches, batch_output_dir, db_ops):
    """Run the analyses on one preprocessed chunk and persist the outcome.

    A fresh ``SalaryAnalyticsPipeline`` is created for the batch and published
    on ``state.pipeline``. On success the final table is written to a CSV file
    in ``batch_output_dir`` and stored through ``db_ops``; on failure an
    "error" row with an empty DataFrame is stored instead.

    Args:
        chunk: Preprocessed DataFrame for this batch.
        batch_number: 1-based index of the batch.
        total_batches: Total number of batches, or a negative value when the
            total is unknown (CSV uploads pass -1).
        batch_output_dir: Directory that receives the per-batch CSV file.
        db_ops: DatabaseOperations instance used to store batch results.

    Returns:
        BatchResponse: Describes either the successful batch or the error.
        Analysis errors are captured in the response rather than raised; an
        exception raised while recording the failure itself still propagates.
    """
    # Human-readable batch label; the total is omitted when it is unknown.
    label = f"{batch_number} of {total_batches}" if total_batches >= 0 else str(batch_number)

    # The pipeline must be stored on the shared state because the analysis
    # calls below go through state.pipeline.
    state.pipeline = SalaryAnalyticsPipeline()
    state.pipeline.df = chunk

    try:
        batch_start_time = time.time()

        # Run analyses
        state.pipeline.run_keyword_analysis()
        state.pipeline.run_consistent_amount_analysis()
        state.pipeline.run_transaction_type_analysis()

        # Generate reports
        reports = state.pipeline.generate_salary_earner_reports()

        # Add batch metadata to results
        results_df = reports['final_table'].copy()
        results_df['batch_number'] = batch_number
        results_df['total_batches'] = total_batches
        results_df['processed_at'] = datetime.now()

        # Save batch results to CSV
        batch_results_path = os.path.join(batch_output_dir, f"batch_{batch_number}_results.csv")
        results_df.to_csv(batch_results_path, index=False)

        # Save to database
        db_ops.save_batch_to_db(
            batch_number=batch_number,
            total_batches=total_batches,
            results_df=results_df,
            status="success"
        )

        logger.info(f"Batch {label} processed in {time.time() - batch_start_time:.2f} seconds")

        return BatchResponse(
            batch_number=batch_number,
            total_batches=total_batches,
            processed_rows=len(chunk),
            results_path=batch_results_path,
            message=f"Successfully processed batch {label}"
        )
    except Exception as e:
        # A failing batch is recorded and reported, but must not abort the
        # remaining batches of the run.
        error_message = str(e)
        logger.error(f"Error processing batch {batch_number}: {error_message}")

        # Save error to database
        db_ops.save_batch_to_db(
            batch_number=batch_number,
            total_batches=total_batches,
            results_df=pd.DataFrame(),  # Empty DataFrame for error case
            status="error"
        )

        return BatchResponse(
            batch_number=batch_number,
            total_batches=total_batches,
            processed_rows=len(chunk),
            results_path="",
            message=f"Error processing batch {batch_number}: {error_message}"
        )


@router.post("/run/streaming-pipeline", response_model=List[BatchResponse])
async def run_streaming_pipeline(
    source: str = "db",
    batch_size: int = 10000,
    file: Optional[Union[UploadFile, str]] = File(None)
):
    """
    Run the complete salary analytics pipeline in batches.

    Each batch is preprocessed, analysed with a fresh pipeline, written to a
    CSV file in a timestamped ``batch_results_*`` directory and recorded in
    the batch results table.

    Args:
        source (str): Source of data ('db' or 'csv')
        batch_size (int): Number of rows to process in each batch (must be >= 1)
        file (UploadFile, optional): CSV file to load (required if source is 'csv')

    Returns:
        List[BatchResponse]: List of responses for each batch processed

    Raises:
        HTTPException: 400 for an invalid source, a non-positive batch size or
            a missing upload when source is 'csv'; 500 when the database
            connection or results table cannot be set up, or on any
            unexpected error.
    """
    start_time = time.time()
    try:
        if source not in ('db', 'csv'):
            logger.error(f"Invalid source: {source}")
            raise HTTPException(status_code=400, detail="Source must be either 'db' or 'csv'")

        # A zero batch size would divide by zero when counting batches and a
        # negative one would produce a meaningless LIMIT; reject both up front.
        if batch_size < 1:
            logger.error(f"Invalid batch size: {batch_size}")
            raise HTTPException(status_code=400, detail="Batch size must be a positive integer")

        # The annotation also admits a plain string (presumably so an empty
        # form value is accepted); a string has no .read(), so it is treated
        # the same as a missing upload.
        if source == 'csv' and (not file or isinstance(file, str)):
            logger.error("No file provided for CSV source")
            raise HTTPException(status_code=400, detail="File must be provided when loading from CSV")

        # Initialize data loader
        state.data_loader = DataLoader()
        state.data_loader.chunk_size = batch_size

        # Create output directory for batch results
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        batch_output_dir = os.path.join(
            os.path.dirname(OUTPUT_PATHS['final_table']),
            f"batch_results_{timestamp}"
        )
        os.makedirs(batch_output_dir, exist_ok=True)

        # Initialize database operations (needed for both sources, because
        # batch results are always stored in the database)
        if not state.data_loader.connect():
            logger.error("Failed to connect to database")
            raise HTTPException(status_code=500, detail="Failed to connect to database")

        db_ops = DatabaseOperations(state.data_loader.engine)
        if not db_ops.create_batch_results_table():
            logger.error("Failed to create batch results table")
            raise HTTPException(status_code=500, detail="Failed to create batch results table")

        responses = []
        batch_number = 0

        if source == 'csv':
            temp_file_path = None
            try:
                # Save uploaded file temporarily. The path is captured before
                # the upload is read so the finally clause can remove the file
                # even if reading or writing fails part-way.
                with tempfile.NamedTemporaryFile(delete=False, suffix='.csv') as temp_file:
                    temp_file_path = temp_file.name
                    temp_file.write(await file.read())

                # Process CSV in chunks
                for chunk in pd.read_csv(temp_file_path, chunksize=batch_size):
                    batch_number += 1
                    logger.info(f"Processing batch {batch_number}")

                    chunk = _preprocess_chunk(chunk)
                    responses.append(_process_batch(
                        chunk,
                        batch_number,
                        -1,  # Unknown for CSV
                        batch_output_dir,
                        db_ops
                    ))
            finally:
                # Clean up temporary file
                if temp_file_path is not None and os.path.exists(temp_file_path):
                    os.unlink(temp_file_path)
        else:
            # Process database in chunks. connect() already succeeded above,
            # so the existing engine is reused rather than connecting again.
            engine = state.data_loader.engine

            # Get total row count
            with engine.connect() as conn:
                count_query = text(f"SELECT COUNT(*) FROM {TABLE_NAME}")
                total_rows = conn.execute(count_query).scalar()

            total_batches = (total_rows + batch_size - 1) // batch_size

            for offset in range(0, total_rows, batch_size):
                batch_number += 1
                logger.info(f"Processing batch {batch_number} of {total_batches}")

                # Load chunk from database.
                # NOTE(review): there is no ORDER BY, so the rows returned for
                # a given OFFSET are not guaranteed to be stable between
                # queries; add an ordering on a unique column once one is
                # confirmed for this table.
                query = f"SELECT * FROM {TABLE_NAME} LIMIT {batch_size} OFFSET {offset}"
                chunk = pd.read_sql(query, engine)

                if chunk.empty:
                    break

                chunk = _preprocess_chunk(chunk)
                responses.append(_process_batch(
                    chunk,
                    batch_number,
                    total_batches,
                    batch_output_dir,
                    db_ops
                ))

        logger.info(f"Streaming pipeline endpoint completed in {time.time() - start_time:.2f} seconds")
        return responses
    except HTTPException:
        # Deliberate HTTP errors (400s and setup failures) must keep their
        # status code; the generic handler below would turn them into 500s.
        logger.info(f"Streaming pipeline endpoint failed after {time.time() - start_time:.2f} seconds")
        raise
    except Exception as e:
        logger.error(f"Error in streaming pipeline: {str(e)}")
        logger.info(f"Streaming pipeline endpoint failed after {time.time() - start_time:.2f} seconds")
        raise HTTPException(status_code=500, detail=str(e)) from e
|
|
|