Refactor data loading and streaming pipeline endpoints for improved file handling

- Updated `/load-data` endpoint to make the file parameter optional and added validation for CSV uploads.
- Introduced a new dependency function `get_file_if_csv` to streamline file checks when loading data from CSV.
- Enhanced `/run/streaming-pipeline` endpoint to utilize the new file handling logic.
- Improved code readability by re-indenting the column renaming logic (`chunk.rename(columns=...)`).
This commit is contained in:
2025-05-03 15:40:50 +01:00
parent 9c429caa56
commit a060fa69c5
2 changed files with 22 additions and 13 deletions
Binary file not shown.
+22 -13
View File
@@ -2,7 +2,7 @@
FastAPI application for salary analytics.
"""
from fastapi import FastAPI, HTTPException, BackgroundTasks, UploadFile, File
from fastapi import FastAPI, HTTPException, BackgroundTasks, UploadFile, File, Depends
from fastapi.responses import FileResponse
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
@@ -242,13 +242,13 @@ async def run_full_pipeline():
raise HTTPException(status_code=500, detail=str(e))
@app.post("/load-data")
async def load_data(source: str = "db", file: UploadFile = None):
async def load_data(source: str = "db", file: Optional[UploadFile] = File(None)):
"""
Load data from either database or CSV file.
Args:
source (str): Source of data ('db' or 'csv')
file (UploadFile): CSV file to load (required if source is 'csv')
file (UploadFile, optional): CSV file to load (required if source is 'csv')
Returns:
dict: Status of data loading
@@ -288,15 +288,25 @@ async def load_data(source: str = "db", file: UploadFile = None):
logger.error(f"Error loading data: {str(e)}")
raise HTTPException(status_code=500, detail=str(e))
async def get_file_if_csv(source: str, file: Optional[UploadFile] = File(None)):
    """
    Dependency that hands back the uploaded file, insisting on one only for CSV.

    Args:
        source (str): Source of data; a file is only demanded when this is 'csv'
        file (UploadFile, optional): Uploaded file, if the request carried one

    Returns:
        Optional[UploadFile]: The upload exactly as received (may be None)

    Raises:
        HTTPException: Status 400 when source is 'csv' and no file was provided
    """
    # Pass-through case: either the source is not CSV, or a file came with it.
    if source != 'csv' or file:
        return file

    # Only reachable for a CSV source without an accompanying upload.
    raise HTTPException(status_code=400, detail="File must be provided when loading from CSV")
@app.post("/run/streaming-pipeline", response_model=List[BatchResponse])
async def run_streaming_pipeline(source: str = "db", file: UploadFile = None, batch_size: int = 10000):
async def run_streaming_pipeline(
source: str = "db",
batch_size: int = 10000,
file: Optional[UploadFile] = Depends(get_file_if_csv)
):
"""
Run the complete salary analytics pipeline in batches.
Args:
source (str): Source of data ('db' or 'csv')
file (UploadFile): CSV file to load (required if source is 'csv')
batch_size (int): Number of rows to process in each batch
file (UploadFile, optional): CSV file to load (required if source is 'csv')
Returns:
List[BatchResponse]: List of responses for each batch processed
@@ -305,9 +315,6 @@ async def run_streaming_pipeline(source: str = "db", file: UploadFile = None, ba
if source not in ['db', 'csv']:
raise HTTPException(status_code=400, detail="Source must be either 'db' or 'csv'")
if source == 'csv' and not file:
raise HTTPException(status_code=400, detail="File must be provided when loading from CSV")
# Initialize data loader
data_loader = DataLoader()
data_loader.chunk_size = batch_size
@@ -326,12 +333,14 @@ async def run_streaming_pipeline(source: str = "db", file: UploadFile = None, ba
chunk['trx_start_date'] = pd.to_datetime(chunk['trx_start_date'])
chunk['trx_end_date'] = pd.to_datetime(chunk['trx_end_date'])
# Rename columns
chunk = chunk.rename(columns={
'd1': 'trx_type',
'd2': 'trx_subtype',
'd3': 'initiated_by',
'd4': 'customer_id'
})
'd1': 'trx_type',
'd2': 'trx_subtype',
'd3': 'initiated_by',
'd4': 'customer_id'
})
chunk = chunk.dropna()
return chunk