Refactor data loading and streaming pipeline endpoints for improved file handling
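- Updated the `/load-data` endpoint so that the `file` parameter is explicitly optional (`Optional[UploadFile] = File(None)`) and documented as required only when `source` is `'csv'`.
- Introduced a new dependency function, `get_file_if_csv`, which returns HTTP 400 when `source` is `'csv'` and no file was uploaded.
- Changed the `/run/streaming-pipeline` endpoint to obtain its file through `Depends(get_file_if_csv)` and removed the duplicate inline check.
- Improved code readability by adding a comment to the column-renaming step of the chunk processing.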
This commit is contained in:
Binary file not shown.
+22
-13
@@ -2,7 +2,7 @@
|
|||||||
FastAPI application for salary analytics.
|
FastAPI application for salary analytics.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from fastapi import FastAPI, HTTPException, BackgroundTasks, UploadFile, File
|
from fastapi import FastAPI, HTTPException, BackgroundTasks, UploadFile, File, Depends
|
||||||
from fastapi.responses import FileResponse
|
from fastapi.responses import FileResponse
|
||||||
from fastapi.middleware.cors import CORSMiddleware
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
@@ -242,13 +242,13 @@ async def run_full_pipeline():
|
|||||||
raise HTTPException(status_code=500, detail=str(e))
|
raise HTTPException(status_code=500, detail=str(e))
|
||||||
|
|
||||||
@app.post("/load-data")
|
@app.post("/load-data")
|
||||||
async def load_data(source: str = "db", file: UploadFile = None):
|
async def load_data(source: str = "db", file: Optional[UploadFile] = File(None)):
|
||||||
"""
|
"""
|
||||||
Load data from either database or CSV file.
|
Load data from either database or CSV file.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
source (str): Source of data ('db' or 'csv')
|
source (str): Source of data ('db' or 'csv')
|
||||||
file (UploadFile): CSV file to load (required if source is 'csv')
|
file (UploadFile, optional): CSV file to load (required if source is 'csv')
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
dict: Status of data loading
|
dict: Status of data loading
|
||||||
@@ -288,15 +288,25 @@ async def load_data(source: str = "db", file: UploadFile = None):
|
|||||||
logger.error(f"Error loading data: {str(e)}")
|
logger.error(f"Error loading data: {str(e)}")
|
||||||
raise HTTPException(status_code=500, detail=str(e))
|
raise HTTPException(status_code=500, detail=str(e))
|
||||||
|
|
||||||
|
async def get_file_if_csv(source: str, file: Optional[UploadFile] = File(None)):
|
||||||
|
"""Dependency to handle file upload only when source is csv."""
|
||||||
|
if source == 'csv' and not file:
|
||||||
|
raise HTTPException(status_code=400, detail="File must be provided when loading from CSV")
|
||||||
|
return file
|
||||||
|
|
||||||
@app.post("/run/streaming-pipeline", response_model=List[BatchResponse])
|
@app.post("/run/streaming-pipeline", response_model=List[BatchResponse])
|
||||||
async def run_streaming_pipeline(source: str = "db", file: UploadFile = None, batch_size: int = 10000):
|
async def run_streaming_pipeline(
|
||||||
|
source: str = "db",
|
||||||
|
batch_size: int = 10000,
|
||||||
|
file: Optional[UploadFile] = Depends(get_file_if_csv)
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Run the complete salary analytics pipeline in batches.
|
Run the complete salary analytics pipeline in batches.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
source (str): Source of data ('db' or 'csv')
|
source (str): Source of data ('db' or 'csv')
|
||||||
file (UploadFile): CSV file to load (required if source is 'csv')
|
|
||||||
batch_size (int): Number of rows to process in each batch
|
batch_size (int): Number of rows to process in each batch
|
||||||
|
file (UploadFile, optional): CSV file to load (required if source is 'csv')
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
List[BatchResponse]: List of responses for each batch processed
|
List[BatchResponse]: List of responses for each batch processed
|
||||||
@@ -305,9 +315,6 @@ async def run_streaming_pipeline(source: str = "db", file: UploadFile = None, ba
|
|||||||
if source not in ['db', 'csv']:
|
if source not in ['db', 'csv']:
|
||||||
raise HTTPException(status_code=400, detail="Source must be either 'db' or 'csv'")
|
raise HTTPException(status_code=400, detail="Source must be either 'db' or 'csv'")
|
||||||
|
|
||||||
if source == 'csv' and not file:
|
|
||||||
raise HTTPException(status_code=400, detail="File must be provided when loading from CSV")
|
|
||||||
|
|
||||||
# Initialize data loader
|
# Initialize data loader
|
||||||
data_loader = DataLoader()
|
data_loader = DataLoader()
|
||||||
data_loader.chunk_size = batch_size
|
data_loader.chunk_size = batch_size
|
||||||
@@ -326,12 +333,14 @@ async def run_streaming_pipeline(source: str = "db", file: UploadFile = None, ba
|
|||||||
chunk['trx_start_date'] = pd.to_datetime(chunk['trx_start_date'])
|
chunk['trx_start_date'] = pd.to_datetime(chunk['trx_start_date'])
|
||||||
chunk['trx_end_date'] = pd.to_datetime(chunk['trx_end_date'])
|
chunk['trx_end_date'] = pd.to_datetime(chunk['trx_end_date'])
|
||||||
|
|
||||||
|
# Rename columns
|
||||||
chunk = chunk.rename(columns={
|
chunk = chunk.rename(columns={
|
||||||
'd1': 'trx_type',
|
'd1': 'trx_type',
|
||||||
'd2': 'trx_subtype',
|
'd2': 'trx_subtype',
|
||||||
'd3': 'initiated_by',
|
'd3': 'initiated_by',
|
||||||
'd4': 'customer_id'
|
'd4': 'customer_id'
|
||||||
})
|
})
|
||||||
|
|
||||||
chunk = chunk.dropna()
|
chunk = chunk.dropna()
|
||||||
|
|
||||||
return chunk
|
return chunk
|
||||||
|
|||||||
Reference in New Issue
Block a user