diff --git a/salary_analytics/__pycache__/api.cpython-311.pyc b/salary_analytics/__pycache__/api.cpython-311.pyc index 374ad32..af60713 100644 Binary files a/salary_analytics/__pycache__/api.cpython-311.pyc and b/salary_analytics/__pycache__/api.cpython-311.pyc differ diff --git a/salary_analytics/api.py b/salary_analytics/api.py index 8172a9b..ed4e7e5 100644 --- a/salary_analytics/api.py +++ b/salary_analytics/api.py @@ -2,7 +2,7 @@ FastAPI application for salary analytics. """ -from fastapi import FastAPI, HTTPException, BackgroundTasks, UploadFile, File +from fastapi import FastAPI, HTTPException, BackgroundTasks, UploadFile, File, Depends from fastapi.responses import FileResponse from fastapi.middleware.cors import CORSMiddleware from pydantic import BaseModel @@ -242,13 +242,13 @@ async def run_full_pipeline(): raise HTTPException(status_code=500, detail=str(e)) @app.post("/load-data") -async def load_data(source: str = "db", file: UploadFile = None): +async def load_data(source: str = "db", file: Optional[UploadFile] = File(None)): """ Load data from either database or CSV file. Args: source (str): Source of data ('db' or 'csv') - file (UploadFile): CSV file to load (required if source is 'csv') + file (UploadFile, optional): CSV file to load (required if source is 'csv') Returns: dict: Status of data loading @@ -288,15 +288,25 @@ async def load_data(source: str = "db", file: UploadFile = None): logger.error(f"Error loading data: {str(e)}") raise HTTPException(status_code=500, detail=str(e)) +async def get_file_if_csv(source: str, file: Optional[UploadFile] = File(None)): + """Dependency to handle file upload only when source is csv.""" + if source == 'csv' and not file: + raise HTTPException(status_code=400, detail="File must be provided when loading from CSV") + return file + @app.post("/run/streaming-pipeline", response_model=List[BatchResponse]) -async def run_streaming_pipeline(source: str = "db", file: UploadFile = None, batch_size: int = 10000): +async def run_streaming_pipeline( + source: str = "db", + batch_size: int = 10000, + file: Optional[UploadFile] = Depends(get_file_if_csv) +): """ Run the complete salary analytics pipeline in batches. Args: source (str): Source of data ('db' or 'csv') - file (UploadFile): CSV file to load (required if source is 'csv') batch_size (int): Number of rows to process in each batch + file (UploadFile, optional): CSV file to load (required if source is 'csv') Returns: List[BatchResponse]: List of responses for each batch processed @@ -305,9 +315,6 @@ async def run_streaming_pipeline(source: str = "db", file: UploadFile = None, ba if source not in ['db', 'csv']: raise HTTPException(status_code=400, detail="Source must be either 'db' or 'csv'") - if source == 'csv' and not file: - raise HTTPException(status_code=400, detail="File must be provided when loading from CSV") - # Initialize data loader data_loader = DataLoader() data_loader.chunk_size = batch_size @@ -326,12 +333,14 @@ async def run_streaming_pipeline(source: str = "db", file: UploadFile = None, ba chunk['trx_start_date'] = pd.to_datetime(chunk['trx_start_date']) chunk['trx_end_date'] = pd.to_datetime(chunk['trx_end_date']) + # Rename columns chunk = chunk.rename(columns={ - 'd1': 'trx_type', - 'd2': 'trx_subtype', - 'd3': 'initiated_by', - 'd4': 'customer_id' - }) + 'd1': 'trx_type', + 'd2': 'trx_subtype', + 'd3': 'initiated_by', + 'd4': 'customer_id' + }) + chunk = chunk.dropna() return chunk