[add]: refactoring and cleanup

This commit is contained in:
VivianDee
2025-09-07 23:44:25 +01:00
parent 6de9583aaf
commit d9b6a7e92e
8 changed files with 161 additions and 35 deletions
+1 -1
View File
@@ -20,7 +20,7 @@ def create_app():
migrate.init_app(app, db)
# Register blueprints or CLI commands here if needed
from .commands import commands
from app.analytics.commands import commands
app.cli.add_command(commands.upload_xls_cli)
return app
+32 -32
View File
@@ -4,27 +4,23 @@ FastAPI application for salary analytics.
from fastapi import FastAPI, HTTPException, BackgroundTasks, UploadFile, File, Depends
from fastapi.responses import FileResponse
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from typing import Optional, Dict, List, Union
import os
import socket
import logging
from typing import Optional, List, Union
import pandas as pd
import tempfile
from datetime import datetime
from sqlalchemy import text, Table, Column, Integer, String, Float, DateTime, MetaData
import numpy as np
from sqlalchemy import text
import warnings
import time
from .analytics.services.main import SalaryAnalyticsPipeline
from .config import OUTPUT_PATHS, TABLE_NAME, BATCH_RESULTS_TABLE
from .data_loader import DataLoader
from .salary_predictor import SalaryPredictor
from .salary_earner_analyzer import SalaryEarnerAnalyzer
from .db_operations import DatabaseOperations
from .analytics.integrations.salary_detect import SalaryDetect
from app.analytics.services.main import SalaryAnalyticsPipeline
from app.config import OUTPUT_PATHS, TABLE_NAME
from app.analytics.services.data_loader import DataLoader
from app.analytics.middlewares.middleware import add_middlewares
from app.models.db_operations import DatabaseOperations
from app.analytics.integrations.salary_detect import SalaryDetect
from app.utils.logger import logger
from app.analytics.helpers.response_helpers import AnalysisResponse, BatchResponse
# Suppress warnings
@@ -38,13 +34,7 @@ app = FastAPI(
)
# Add CORS middleware
app.add_middleware(
CORSMiddleware,
allow_origins=["*"], # Allows all origins
allow_credentials=True,
allow_methods=["*"], # Allows all methods
allow_headers=["*"], # Allows all headers
)
add_middlewares(app)
# Global pipeline instance
pipeline = SalaryAnalyticsPipeline()
@@ -57,19 +47,7 @@ salary_earner_analyzer = None
salary_detect = SalaryDetect()
class AnalysisResponse(BaseModel):
"""Response model for analysis endpoints."""
message: str
data: Optional[Dict] = None
file_path: Optional[str] = None
class BatchResponse(BaseModel):
"""Response model for batch processing."""
batch_number: int
total_batches: int
processed_rows: int
results_path: str
message: str
def check_data_loaded():
"""Check if data is loaded before running analytics."""
@@ -103,6 +81,8 @@ async def startup_event():
logger.error(f"Error during startup: {str(e)}")
raise
@app.get("/")
async def root():
"""Root endpoint."""
@@ -112,6 +92,8 @@ async def root():
logger.info(f"Root endpoint completed in {time.time() - start_time:.2f} seconds")
return response
@app.get("/health")
async def health_check():
"""Health check endpoint."""
@@ -121,6 +103,8 @@ async def health_check():
logger.info(f"Health check completed in {time.time() - start_time:.2f} seconds")
return response
@app.post("/analyze/keyword", response_model=AnalysisResponse)
async def analyze_keyword():
"""Run keyword-based salary transaction analysis."""
@@ -141,6 +125,8 @@ async def analyze_keyword():
logger.info(f"Keyword analysis endpoint failed after {time.time() - start_time:.2f} seconds")
raise HTTPException(status_code=500, detail=str(e))
@app.post("/analyze/consistent-amount", response_model=AnalysisResponse)
async def analyze_consistent_amount():
"""Run consistent amount transaction analysis."""
@@ -161,6 +147,8 @@ async def analyze_consistent_amount():
logger.info(f"Consistent amount analysis endpoint failed after {time.time() - start_time:.2f} seconds")
raise HTTPException(status_code=500, detail=str(e))
@app.post("/analyze/transaction-type", response_model=AnalysisResponse)
async def analyze_transaction_type():
"""Run transaction type analysis."""
@@ -181,6 +169,8 @@ async def analyze_transaction_type():
logger.info(f"Transaction type analysis endpoint failed after {time.time() - start_time:.2f} seconds")
raise HTTPException(status_code=500, detail=str(e))
@app.post("/generate/reports", response_model=AnalysisResponse)
async def generate_reports(background_tasks: BackgroundTasks):
"""Generate salary earner reports."""
@@ -205,6 +195,8 @@ async def generate_reports(background_tasks: BackgroundTasks):
logger.info(f"Report generation endpoint failed after {time.time() - start_time:.2f} seconds")
raise HTTPException(status_code=500, detail=str(e))
@app.post("/train/models", response_model=AnalysisResponse)
async def train_models():
"""Train salary prediction models."""
@@ -224,6 +216,8 @@ async def train_models():
logger.info(f"Model training endpoint failed after {time.time() - start_time:.2f} seconds")
raise HTTPException(status_code=500, detail=str(e))
@app.get("/download/{report_type}")
async def download_report(report_type: str):
"""Download generated reports."""
@@ -264,6 +258,8 @@ async def download_report(report_type: str):
logger.info(f"Download endpoint failed after {time.time() - start_time:.2f} seconds")
raise HTTPException(status_code=500, detail=str(e))
@app.post("/run/pipeline", response_model=AnalysisResponse)
async def run_full_pipeline():
"""Run the complete salary analytics pipeline."""
@@ -288,6 +284,8 @@ async def run_full_pipeline():
logger.info(f"Full pipeline endpoint failed after {time.time() - start_time:.2f} seconds")
raise HTTPException(status_code=500, detail=str(e))
@app.post("/load-data")
async def load_data(source: str = "db", file: Optional[UploadFile] = File(None)):
"""
@@ -351,6 +349,8 @@ async def get_file_if_csv(source: str, file: Optional[UploadFile] = File(None)):
raise HTTPException(status_code=400, detail="File must be provided when loading from CSV")
return file
@app.post("/run/streaming-pipeline", response_model=List[BatchResponse])
async def run_streaming_pipeline(
source: str = "db",
+17
View File
@@ -0,0 +1,17 @@
from typing import Optional, Dict, List, Union
from pydantic import BaseModel
class AnalysisResponse(BaseModel):
"""Response model for analysis endpoints."""
message: str
data: Optional[Dict] = None
file_path: Optional[str] = None
class BatchResponse(BaseModel):
"""Response model for batch processing."""
batch_number: int
total_batches: int
processed_rows: int
results_path: str
message: str
+11
View File
@@ -0,0 +1,11 @@
from fastapi.middleware.cors import CORSMiddleware
from fastapi import FastAPI
def add_middlewares(app: FastAPI):
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
+2 -1
View File
@@ -22,9 +22,10 @@ os.makedirs(PLOTS_DIR, exist_ok=True)
os.makedirs(CSV_DIR, exist_ok=True)
os.makedirs(MODEL_DIR, exist_ok=True)
# Database Configuration
DB_CONFIG = {
"user": os.getenv("DB_USER"), # Default value as fallback
"user": os.getenv("DB_USER"),
"password": os.getenv("DB_PASSWORD"),
"name": os.getenv("DB_NAME"),
"port": os.getenv("DB_PORT"),
+97
View File
@@ -0,0 +1,97 @@
from sqlalchemy import Column, Integer, String, DateTime, Numeric, Boolean, func
from sqlalchemy.orm import declarative_base, Session
from datetime import datetime
from app.utils.logger import logger
from app.extensions import db
class BatchResult(db.Model):
__tablename__ = "salary_analytics_batch_results"
id = Column(Integer, primary_key=True, autoincrement=True)
batch_number = Column(Integer, nullable=False)
total_batches = Column(Integer, nullable=False)
processed_at = Column(DateTime, default=datetime.utcnow)
accountid = Column(String, nullable=False)
num_months = Column(Integer)
least_inflow_6m = Column(Numeric)
avg_monthly_salary = Column(Numeric)
estimated_next_amount = Column(Numeric)
estimated_next_date = Column(DateTime)
is_45day_salary = Column(Boolean, default=False)
is_2months_salary = Column(Boolean, default=False)
status = Column(String, default="success")
@classmethod
def save_batch(cls, session: Session, batch_number, total_batches, results_df, status="success"):
"""Save batch results into DB using ORM bulk insert."""
try:
results_df["batch_number"] = batch_number
results_df["total_batches"] = total_batches
results_df["processed_at"] = datetime.utcnow()
results_df["status"] = status
# Normalize boolean columns
results_df["is_45day_salary"] = results_df.get("45daysalary", False)
results_df["is_2months_salary"] = results_df.get("2monthssalary", False)
# Convert to list of ORM objects
records = [
cls(**row)
for row in results_df.to_dict("records")
]
session.bulk_save_objects(records)
session.commit()
logger.info(f"Saved batch {batch_number} successfully.")
return True
except Exception as e:
session.rollback()
logger.error(f"Error saving batch {batch_number}: {str(e)}")
return False
@classmethod
def get_batch_status(cls, session: Session, batch_number: int):
"""Return summary info about one batch."""
try:
result = (
session.query(
cls.batch_number,
cls.total_batches,
cls.processed_at,
func.count().label("total_records"),
func.sum(func.case((cls.status == "success", 1), else_=0)).label("successful_records"),
func.sum(func.case((cls.status == "error", 1), else_=0)).label("failed_records"),
)
.filter(cls.batch_number == batch_number)
.group_by(cls.batch_number, cls.total_batches, cls.processed_at)
.order_by(cls.processed_at.desc())
.first()
)
return dict(result._mapping) if result else None
except Exception as e:
logger.error(f"Error fetching batch {batch_number} status: {str(e)}")
return None
@classmethod
def get_all_batches(cls, session: Session):
"""Return summaries for all batches."""
try:
results = (
session.query(
cls.batch_number,
cls.total_batches,
cls.processed_at,
func.count().label("total_records"),
func.sum(func.case((cls.status == "success", 1), else_=0)).label("successful_records"),
func.sum(func.case((cls.status == "error", 1), else_=0)).label("failed_records"),
)
.group_by(cls.batch_number, cls.total_batches, cls.processed_at)
.order_by(cls.batch_number)
.all()
)
return [dict(r._mapping) for r in results]
except Exception as e:
logger.error(f"Error fetching all batches: {str(e)}")
return []
@@ -3,7 +3,7 @@ Database operations module for salary analytics.
"""
from sqlalchemy import text
from .config import BATCH_RESULTS_TABLE
from ..config import BATCH_RESULTS_TABLE
from datetime import datetime
from app.utils.logger import logger