34 lines
1.2 KiB
Python
34 lines
1.2 KiB
Python
import time
|
|
import logging
|
|
from fastapi import APIRouter, HTTPException
|
|
from app.salary_analytics.services.main import SalaryAnalyticsPipeline
|
|
from app.salary_analytics.helpers.data_checks import check_data_loaded
|
|
from app.salary_analytics.helpers.response_helpers import AnalysisResponse
|
|
from app.salary_analytics.core.state import state
|
|
from app.utils.logger import logger
|
|
|
|
|
|
|
|
# Router carrying the salary-analytics model-training endpoint declared in this
# module; presumably included into the FastAPI application elsewhere (verify).
router: APIRouter = APIRouter()
|
|
|
|
|
|
@router.post("/train/models", response_model=AnalysisResponse)
async def train_models():
    """Train salary prediction models on the currently loaded data.

    Returns:
        AnalysisResponse: confirmation message once training has finished.

    Raises:
        HTTPException: any ``HTTPException`` raised while handling the request
            (presumably by ``check_data_loaded`` when no data is loaded) is
            re-raised unchanged so its status code reaches the client; every
            other exception is converted into a 500 whose ``detail`` is the
            exception message.
    """
    # perf_counter is monotonic, so the measured duration is not skewed by
    # wall-clock adjustments (NTP sync, manual changes) the way time.time() is.
    start_time = time.perf_counter()

    try:
        check_data_loaded()

        logger.info("Starting model training...")
        # NOTE(review): this is a synchronous call inside an ``async def``
        # handler, so it blocks the event loop (and every other request) for
        # the whole training run. Consider ``run_in_threadpool`` or a plain
        # ``def`` endpoint once the pipeline is confirmed safe to run
        # concurrently with other requests.
        state.pipeline.train_salary_prediction_models()
        logger.info("Models trained successfully")

        response = AnalysisResponse(message="Models trained successfully")

        logger.info(
            f"Model training endpoint completed in {time.perf_counter() - start_time:.2f} seconds"
        )
        return response

    except HTTPException:
        # Deliberate HTTP errors already carry the right status code and
        # detail; re-wrapping them below would mask e.g. a 4xx as a 500.
        logger.info(
            f"Model training endpoint failed after {time.perf_counter() - start_time:.2f} seconds"
        )
        raise

    except Exception as e:
        # logger.exception also records the traceback, which a plain
        # logger.error(str(e)) discarded.
        logger.exception(f"Error in model training: {str(e)}")
        logger.info(
            f"Model training endpoint failed after {time.perf_counter() - start_time:.2f} seconds"
        )
        # Chain the original error so the cause is kept in tracebacks.
        raise HTTPException(status_code=500, detail=str(e)) from e
|
|
|