[add]: refactoring and cleanup
This commit is contained in:
+1
-1
@@ -20,7 +20,7 @@ def create_app():
|
|||||||
migrate.init_app(app, db)
|
migrate.init_app(app, db)
|
||||||
|
|
||||||
# Register blueprints or CLI commands here if needed
|
# Register blueprints or CLI commands here if needed
|
||||||
from .commands import commands
|
from app.analytics.commands import commands
|
||||||
app.cli.add_command(commands.upload_xls_cli)
|
app.cli.add_command(commands.upload_xls_cli)
|
||||||
|
|
||||||
return app
|
return app
|
||||||
@@ -4,27 +4,23 @@ FastAPI application for salary analytics.
|
|||||||
|
|
||||||
from fastapi import FastAPI, HTTPException, BackgroundTasks, UploadFile, File, Depends
|
from fastapi import FastAPI, HTTPException, BackgroundTasks, UploadFile, File, Depends
|
||||||
from fastapi.responses import FileResponse
|
from fastapi.responses import FileResponse
|
||||||
from fastapi.middleware.cors import CORSMiddleware
|
|
||||||
from pydantic import BaseModel
|
|
||||||
from typing import Optional, Dict, List, Union
|
|
||||||
import os
|
import os
|
||||||
import socket
|
import socket
|
||||||
import logging
|
from typing import Optional, List, Union
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
import tempfile
|
import tempfile
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from sqlalchemy import text, Table, Column, Integer, String, Float, DateTime, MetaData
|
from sqlalchemy import text
|
||||||
import numpy as np
|
|
||||||
import warnings
|
import warnings
|
||||||
import time
|
import time
|
||||||
from .analytics.services.main import SalaryAnalyticsPipeline
|
from app.analytics.services.main import SalaryAnalyticsPipeline
|
||||||
from .config import OUTPUT_PATHS, TABLE_NAME, BATCH_RESULTS_TABLE
|
from app.config import OUTPUT_PATHS, TABLE_NAME
|
||||||
from .data_loader import DataLoader
|
from app.analytics.services.data_loader import DataLoader
|
||||||
from .salary_predictor import SalaryPredictor
|
from app.analytics.middlewares.middleware import add_middlewares
|
||||||
from .salary_earner_analyzer import SalaryEarnerAnalyzer
|
from app.models.db_operations import DatabaseOperations
|
||||||
from .db_operations import DatabaseOperations
|
from app.analytics.integrations.salary_detect import SalaryDetect
|
||||||
from .analytics.integrations.salary_detect import SalaryDetect
|
|
||||||
from app.utils.logger import logger
|
from app.utils.logger import logger
|
||||||
|
from app.analytics.helpers.response_helpers import AnalysisResponse, BatchResponse
|
||||||
|
|
||||||
|
|
||||||
# Suppress warnings
|
# Suppress warnings
|
||||||
@@ -38,13 +34,7 @@ app = FastAPI(
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Add CORS middleware
|
# Add CORS middleware
|
||||||
app.add_middleware(
|
add_middlewares(app)
|
||||||
CORSMiddleware,
|
|
||||||
allow_origins=["*"], # Allows all origins
|
|
||||||
allow_credentials=True,
|
|
||||||
allow_methods=["*"], # Allows all methods
|
|
||||||
allow_headers=["*"], # Allows all headers
|
|
||||||
)
|
|
||||||
|
|
||||||
# Global pipeline instance
|
# Global pipeline instance
|
||||||
pipeline = SalaryAnalyticsPipeline()
|
pipeline = SalaryAnalyticsPipeline()
|
||||||
@@ -57,19 +47,7 @@ salary_earner_analyzer = None
|
|||||||
|
|
||||||
salary_detect = SalaryDetect()
|
salary_detect = SalaryDetect()
|
||||||
|
|
||||||
class AnalysisResponse(BaseModel):
|
|
||||||
"""Response model for analysis endpoints."""
|
|
||||||
message: str
|
|
||||||
data: Optional[Dict] = None
|
|
||||||
file_path: Optional[str] = None
|
|
||||||
|
|
||||||
class BatchResponse(BaseModel):
|
|
||||||
"""Response model for batch processing."""
|
|
||||||
batch_number: int
|
|
||||||
total_batches: int
|
|
||||||
processed_rows: int
|
|
||||||
results_path: str
|
|
||||||
message: str
|
|
||||||
|
|
||||||
def check_data_loaded():
|
def check_data_loaded():
|
||||||
"""Check if data is loaded before running analytics."""
|
"""Check if data is loaded before running analytics."""
|
||||||
@@ -103,6 +81,8 @@ async def startup_event():
|
|||||||
logger.error(f"Error during startup: {str(e)}")
|
logger.error(f"Error during startup: {str(e)}")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@app.get("/")
|
@app.get("/")
|
||||||
async def root():
|
async def root():
|
||||||
"""Root endpoint."""
|
"""Root endpoint."""
|
||||||
@@ -112,6 +92,8 @@ async def root():
|
|||||||
logger.info(f"Root endpoint completed in {time.time() - start_time:.2f} seconds")
|
logger.info(f"Root endpoint completed in {time.time() - start_time:.2f} seconds")
|
||||||
return response
|
return response
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@app.get("/health")
|
@app.get("/health")
|
||||||
async def health_check():
|
async def health_check():
|
||||||
"""Health check endpoint."""
|
"""Health check endpoint."""
|
||||||
@@ -121,6 +103,8 @@ async def health_check():
|
|||||||
logger.info(f"Health check completed in {time.time() - start_time:.2f} seconds")
|
logger.info(f"Health check completed in {time.time() - start_time:.2f} seconds")
|
||||||
return response
|
return response
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@app.post("/analyze/keyword", response_model=AnalysisResponse)
|
@app.post("/analyze/keyword", response_model=AnalysisResponse)
|
||||||
async def analyze_keyword():
|
async def analyze_keyword():
|
||||||
"""Run keyword-based salary transaction analysis."""
|
"""Run keyword-based salary transaction analysis."""
|
||||||
@@ -141,6 +125,8 @@ async def analyze_keyword():
|
|||||||
logger.info(f"Keyword analysis endpoint failed after {time.time() - start_time:.2f} seconds")
|
logger.info(f"Keyword analysis endpoint failed after {time.time() - start_time:.2f} seconds")
|
||||||
raise HTTPException(status_code=500, detail=str(e))
|
raise HTTPException(status_code=500, detail=str(e))
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@app.post("/analyze/consistent-amount", response_model=AnalysisResponse)
|
@app.post("/analyze/consistent-amount", response_model=AnalysisResponse)
|
||||||
async def analyze_consistent_amount():
|
async def analyze_consistent_amount():
|
||||||
"""Run consistent amount transaction analysis."""
|
"""Run consistent amount transaction analysis."""
|
||||||
@@ -161,6 +147,8 @@ async def analyze_consistent_amount():
|
|||||||
logger.info(f"Consistent amount analysis endpoint failed after {time.time() - start_time:.2f} seconds")
|
logger.info(f"Consistent amount analysis endpoint failed after {time.time() - start_time:.2f} seconds")
|
||||||
raise HTTPException(status_code=500, detail=str(e))
|
raise HTTPException(status_code=500, detail=str(e))
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@app.post("/analyze/transaction-type", response_model=AnalysisResponse)
|
@app.post("/analyze/transaction-type", response_model=AnalysisResponse)
|
||||||
async def analyze_transaction_type():
|
async def analyze_transaction_type():
|
||||||
"""Run transaction type analysis."""
|
"""Run transaction type analysis."""
|
||||||
@@ -181,6 +169,8 @@ async def analyze_transaction_type():
|
|||||||
logger.info(f"Transaction type analysis endpoint failed after {time.time() - start_time:.2f} seconds")
|
logger.info(f"Transaction type analysis endpoint failed after {time.time() - start_time:.2f} seconds")
|
||||||
raise HTTPException(status_code=500, detail=str(e))
|
raise HTTPException(status_code=500, detail=str(e))
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@app.post("/generate/reports", response_model=AnalysisResponse)
|
@app.post("/generate/reports", response_model=AnalysisResponse)
|
||||||
async def generate_reports(background_tasks: BackgroundTasks):
|
async def generate_reports(background_tasks: BackgroundTasks):
|
||||||
"""Generate salary earner reports."""
|
"""Generate salary earner reports."""
|
||||||
@@ -205,6 +195,8 @@ async def generate_reports(background_tasks: BackgroundTasks):
|
|||||||
logger.info(f"Report generation endpoint failed after {time.time() - start_time:.2f} seconds")
|
logger.info(f"Report generation endpoint failed after {time.time() - start_time:.2f} seconds")
|
||||||
raise HTTPException(status_code=500, detail=str(e))
|
raise HTTPException(status_code=500, detail=str(e))
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@app.post("/train/models", response_model=AnalysisResponse)
|
@app.post("/train/models", response_model=AnalysisResponse)
|
||||||
async def train_models():
|
async def train_models():
|
||||||
"""Train salary prediction models."""
|
"""Train salary prediction models."""
|
||||||
@@ -224,6 +216,8 @@ async def train_models():
|
|||||||
logger.info(f"Model training endpoint failed after {time.time() - start_time:.2f} seconds")
|
logger.info(f"Model training endpoint failed after {time.time() - start_time:.2f} seconds")
|
||||||
raise HTTPException(status_code=500, detail=str(e))
|
raise HTTPException(status_code=500, detail=str(e))
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@app.get("/download/{report_type}")
|
@app.get("/download/{report_type}")
|
||||||
async def download_report(report_type: str):
|
async def download_report(report_type: str):
|
||||||
"""Download generated reports."""
|
"""Download generated reports."""
|
||||||
@@ -264,6 +258,8 @@ async def download_report(report_type: str):
|
|||||||
logger.info(f"Download endpoint failed after {time.time() - start_time:.2f} seconds")
|
logger.info(f"Download endpoint failed after {time.time() - start_time:.2f} seconds")
|
||||||
raise HTTPException(status_code=500, detail=str(e))
|
raise HTTPException(status_code=500, detail=str(e))
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@app.post("/run/pipeline", response_model=AnalysisResponse)
|
@app.post("/run/pipeline", response_model=AnalysisResponse)
|
||||||
async def run_full_pipeline():
|
async def run_full_pipeline():
|
||||||
"""Run the complete salary analytics pipeline."""
|
"""Run the complete salary analytics pipeline."""
|
||||||
@@ -288,6 +284,8 @@ async def run_full_pipeline():
|
|||||||
logger.info(f"Full pipeline endpoint failed after {time.time() - start_time:.2f} seconds")
|
logger.info(f"Full pipeline endpoint failed after {time.time() - start_time:.2f} seconds")
|
||||||
raise HTTPException(status_code=500, detail=str(e))
|
raise HTTPException(status_code=500, detail=str(e))
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@app.post("/load-data")
|
@app.post("/load-data")
|
||||||
async def load_data(source: str = "db", file: Optional[UploadFile] = File(None)):
|
async def load_data(source: str = "db", file: Optional[UploadFile] = File(None)):
|
||||||
"""
|
"""
|
||||||
@@ -351,6 +349,8 @@ async def get_file_if_csv(source: str, file: Optional[UploadFile] = File(None)):
|
|||||||
raise HTTPException(status_code=400, detail="File must be provided when loading from CSV")
|
raise HTTPException(status_code=400, detail="File must be provided when loading from CSV")
|
||||||
return file
|
return file
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@app.post("/run/streaming-pipeline", response_model=List[BatchResponse])
|
@app.post("/run/streaming-pipeline", response_model=List[BatchResponse])
|
||||||
async def run_streaming_pipeline(
|
async def run_streaming_pipeline(
|
||||||
source: str = "db",
|
source: str = "db",
|
||||||
@@ -0,0 +1,17 @@
|
|||||||
|
from typing import Optional, Dict, List, Union
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
|
||||||
|
class AnalysisResponse(BaseModel):
|
||||||
|
"""Response model for analysis endpoints."""
|
||||||
|
message: str
|
||||||
|
data: Optional[Dict] = None
|
||||||
|
file_path: Optional[str] = None
|
||||||
|
|
||||||
|
class BatchResponse(BaseModel):
|
||||||
|
"""Response model for batch processing."""
|
||||||
|
batch_number: int
|
||||||
|
total_batches: int
|
||||||
|
processed_rows: int
|
||||||
|
results_path: str
|
||||||
|
message: str
|
||||||
@@ -0,0 +1,11 @@
|
|||||||
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
|
from fastapi import FastAPI
|
||||||
|
|
||||||
|
def add_middlewares(app: FastAPI):
|
||||||
|
app.add_middleware(
|
||||||
|
CORSMiddleware,
|
||||||
|
allow_origins=["*"],
|
||||||
|
allow_credentials=True,
|
||||||
|
allow_methods=["*"],
|
||||||
|
allow_headers=["*"],
|
||||||
|
)
|
||||||
+2
-1
@@ -22,9 +22,10 @@ os.makedirs(PLOTS_DIR, exist_ok=True)
|
|||||||
os.makedirs(CSV_DIR, exist_ok=True)
|
os.makedirs(CSV_DIR, exist_ok=True)
|
||||||
os.makedirs(MODEL_DIR, exist_ok=True)
|
os.makedirs(MODEL_DIR, exist_ok=True)
|
||||||
|
|
||||||
|
|
||||||
# Database Configuration
|
# Database Configuration
|
||||||
DB_CONFIG = {
|
DB_CONFIG = {
|
||||||
"user": os.getenv("DB_USER"), # Default value as fallback
|
"user": os.getenv("DB_USER"),
|
||||||
"password": os.getenv("DB_PASSWORD"),
|
"password": os.getenv("DB_PASSWORD"),
|
||||||
"name": os.getenv("DB_NAME"),
|
"name": os.getenv("DB_NAME"),
|
||||||
"port": os.getenv("DB_PORT"),
|
"port": os.getenv("DB_PORT"),
|
||||||
|
|||||||
@@ -0,0 +1,97 @@
|
|||||||
|
from sqlalchemy import Column, Integer, String, DateTime, Numeric, Boolean, func
|
||||||
|
from sqlalchemy.orm import declarative_base, Session
|
||||||
|
from datetime import datetime
|
||||||
|
from app.utils.logger import logger
|
||||||
|
from app.extensions import db
|
||||||
|
|
||||||
|
|
||||||
|
class BatchResult(db.Model):
|
||||||
|
__tablename__ = "salary_analytics_batch_results"
|
||||||
|
|
||||||
|
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||||
|
batch_number = Column(Integer, nullable=False)
|
||||||
|
total_batches = Column(Integer, nullable=False)
|
||||||
|
processed_at = Column(DateTime, default=datetime.utcnow)
|
||||||
|
accountid = Column(String, nullable=False)
|
||||||
|
num_months = Column(Integer)
|
||||||
|
least_inflow_6m = Column(Numeric)
|
||||||
|
avg_monthly_salary = Column(Numeric)
|
||||||
|
estimated_next_amount = Column(Numeric)
|
||||||
|
estimated_next_date = Column(DateTime)
|
||||||
|
is_45day_salary = Column(Boolean, default=False)
|
||||||
|
is_2months_salary = Column(Boolean, default=False)
|
||||||
|
status = Column(String, default="success")
|
||||||
|
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def save_batch(cls, session: Session, batch_number, total_batches, results_df, status="success"):
|
||||||
|
"""Save batch results into DB using ORM bulk insert."""
|
||||||
|
try:
|
||||||
|
results_df["batch_number"] = batch_number
|
||||||
|
results_df["total_batches"] = total_batches
|
||||||
|
results_df["processed_at"] = datetime.utcnow()
|
||||||
|
results_df["status"] = status
|
||||||
|
|
||||||
|
# Normalize boolean columns
|
||||||
|
results_df["is_45day_salary"] = results_df.get("45daysalary", False)
|
||||||
|
results_df["is_2months_salary"] = results_df.get("2monthssalary", False)
|
||||||
|
|
||||||
|
# Convert to list of ORM objects
|
||||||
|
records = [
|
||||||
|
cls(**row)
|
||||||
|
for row in results_df.to_dict("records")
|
||||||
|
]
|
||||||
|
|
||||||
|
session.bulk_save_objects(records)
|
||||||
|
session.commit()
|
||||||
|
logger.info(f"Saved batch {batch_number} successfully.")
|
||||||
|
return True
|
||||||
|
except Exception as e:
|
||||||
|
session.rollback()
|
||||||
|
logger.error(f"Error saving batch {batch_number}: {str(e)}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_batch_status(cls, session: Session, batch_number: int):
|
||||||
|
"""Return summary info about one batch."""
|
||||||
|
try:
|
||||||
|
result = (
|
||||||
|
session.query(
|
||||||
|
cls.batch_number,
|
||||||
|
cls.total_batches,
|
||||||
|
cls.processed_at,
|
||||||
|
func.count().label("total_records"),
|
||||||
|
func.sum(func.case((cls.status == "success", 1), else_=0)).label("successful_records"),
|
||||||
|
func.sum(func.case((cls.status == "error", 1), else_=0)).label("failed_records"),
|
||||||
|
)
|
||||||
|
.filter(cls.batch_number == batch_number)
|
||||||
|
.group_by(cls.batch_number, cls.total_batches, cls.processed_at)
|
||||||
|
.order_by(cls.processed_at.desc())
|
||||||
|
.first()
|
||||||
|
)
|
||||||
|
return dict(result._mapping) if result else None
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error fetching batch {batch_number} status: {str(e)}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_all_batches(cls, session: Session):
|
||||||
|
"""Return summaries for all batches."""
|
||||||
|
try:
|
||||||
|
results = (
|
||||||
|
session.query(
|
||||||
|
cls.batch_number,
|
||||||
|
cls.total_batches,
|
||||||
|
cls.processed_at,
|
||||||
|
func.count().label("total_records"),
|
||||||
|
func.sum(func.case((cls.status == "success", 1), else_=0)).label("successful_records"),
|
||||||
|
func.sum(func.case((cls.status == "error", 1), else_=0)).label("failed_records"),
|
||||||
|
)
|
||||||
|
.group_by(cls.batch_number, cls.total_batches, cls.processed_at)
|
||||||
|
.order_by(cls.batch_number)
|
||||||
|
.all()
|
||||||
|
)
|
||||||
|
return [dict(r._mapping) for r in results]
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error fetching all batches: {str(e)}")
|
||||||
|
return []
|
||||||
@@ -3,7 +3,7 @@ Database operations module for salary analytics.
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
from sqlalchemy import text
|
from sqlalchemy import text
|
||||||
from .config import BATCH_RESULTS_TABLE
|
from ..config import BATCH_RESULTS_TABLE
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from app.utils.logger import logger
|
from app.utils.logger import logger
|
||||||
|
|
||||||
Reference in New Issue
Block a user