75 lines
3.0 KiB
Python
75 lines
3.0 KiB
Python
from fastapi import APIRouter, HTTPException, UploadFile, File
|
|
from app.salary_analytics.core.state import state
|
|
from app.utils.logger import logger
|
|
import tempfile, os, time
|
|
from typing import Optional
|
|
|
|
# Router for this module's data-loading endpoints (the @router.post handlers below).
router = APIRouter()
|
|
|
|
|
|
|
|
@router.post("/load-data")
async def load_data(source: str = "db", file: Optional[UploadFile] = File(None)):
    """
    Load data from either database or CSV file.

    Args:
        source (str): Source of data ('db' or 'csv')
        file (UploadFile, optional): CSV file to load (required if source is 'csv')

    Returns:
        dict: Status of data loading -- ``status``, a human-readable
        ``message``, the loaded ``columns`` and the ``row_count``.

    Raises:
        HTTPException: 400 if ``source`` is not 'db'/'csv' or a CSV load has
            no file; 500 if the pipeline reports failure or raises.
    """
    start_time = time.time()

    def log_failure_timing():
        # Single place for the "failed after" timing line every error path emits.
        logger.info(f"Load data endpoint failed after {time.time() - start_time:.2f} seconds")

    try:
        if source not in ['db', 'csv']:
            logger.error(f"Invalid source: {source}")
            log_failure_timing()
            raise HTTPException(status_code=400, detail="Source must be either 'db' or 'csv'")

        if source == 'csv' and not file:
            logger.error("No file provided for CSV source")
            log_failure_timing()
            raise HTTPException(status_code=400, detail="File must be provided when loading from CSV")

        if source == 'csv':
            # Read the upload before creating the temp file so a failed read
            # leaves nothing behind on disk.
            content = await file.read()
            # delete=False so the file survives close() and the pipeline can
            # reopen it by path; it is always removed in the finally below,
            # including when the write itself fails.
            temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.csv')
            try:
                with temp_file:
                    temp_file.write(content)
                success = state.pipeline.load_data(source='csv', file_path=temp_file.name)
            finally:
                os.unlink(temp_file.name)
        else:
            success = state.pipeline.load_data(source='db')

        if not success:
            logger.error("Failed to load data")
            log_failure_timing()
            raise HTTPException(status_code=500, detail="Failed to load data")

        row_count = len(state.pipeline.df)
        response = {
            "status": "success",
            "message": f"Successfully loaded {row_count} rows of data",
            "columns": state.pipeline.df.columns.tolist(),
            "row_count": row_count,
        }
        logger.info(f"Load data endpoint completed in {time.time() - start_time:.2f} seconds")
        return response

    except HTTPException:
        # Already logged where it was raised; re-raise untouched so the
        # intended status (e.g. 400) is not rewritten into a 500 below.
        raise
    except Exception as e:
        logger.error(f"Error loading data: {str(e)}")
        log_failure_timing()
        raise HTTPException(status_code=500, detail=str(e)) from e
|
|
|
|
|
|
@router.post("/load-data-with-file")
async def get_file_if_csv(source: str, file: Optional[UploadFile] = File(None)):
    """Return the uploaded file, requiring one when ``source`` is ``'csv'``.

    Described as a dependency helper, but the decorator also registers it as
    ``POST /load-data-with-file``.

    NOTE(review): when hit as a route this returns the raw ``UploadFile`` (or
    ``None``) as the response body -- confirm the route registration is
    intentional rather than a leftover from the dependency version.

    Raises:
        HTTPException: 400 when ``source`` is ``'csv'`` and no file was sent.
    """
    missing_csv_upload = source == 'csv' and not file
    if not missing_csv_upload:
        return file
    raise HTTPException(status_code=400, detail="File must be provided when loading from CSV")
|