[add]: refactoring and cleanup
This commit is contained in:
+1
-1
@@ -20,7 +20,7 @@ def create_app():
|
||||
migrate.init_app(app, db)
|
||||
|
||||
# Register blueprints or CLI commands here if needed
|
||||
from .commands import commands
|
||||
from app.analytics.commands import commands
|
||||
app.cli.add_command(commands.upload_xls_cli)
|
||||
|
||||
return app
|
||||
@@ -4,27 +4,23 @@ FastAPI application for salary analytics.
|
||||
|
||||
from fastapi import FastAPI, HTTPException, BackgroundTasks, UploadFile, File, Depends
|
||||
from fastapi.responses import FileResponse
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from pydantic import BaseModel
|
||||
from typing import Optional, Dict, List, Union
|
||||
import os
|
||||
import socket
|
||||
import logging
|
||||
from typing import Optional, List, Union
|
||||
import pandas as pd
|
||||
import tempfile
|
||||
from datetime import datetime
|
||||
from sqlalchemy import text, Table, Column, Integer, String, Float, DateTime, MetaData
|
||||
import numpy as np
|
||||
from sqlalchemy import text
|
||||
import warnings
|
||||
import time
|
||||
from .analytics.services.main import SalaryAnalyticsPipeline
|
||||
from .config import OUTPUT_PATHS, TABLE_NAME, BATCH_RESULTS_TABLE
|
||||
from .data_loader import DataLoader
|
||||
from .salary_predictor import SalaryPredictor
|
||||
from .salary_earner_analyzer import SalaryEarnerAnalyzer
|
||||
from .db_operations import DatabaseOperations
|
||||
from .analytics.integrations.salary_detect import SalaryDetect
|
||||
from app.analytics.services.main import SalaryAnalyticsPipeline
|
||||
from app.config import OUTPUT_PATHS, TABLE_NAME
|
||||
from app.analytics.services.data_loader import DataLoader
|
||||
from app.analytics.middlewares.middleware import add_middlewares
|
||||
from app.models.db_operations import DatabaseOperations
|
||||
from app.analytics.integrations.salary_detect import SalaryDetect
|
||||
from app.utils.logger import logger
|
||||
from app.analytics.helpers.response_helpers import AnalysisResponse, BatchResponse
|
||||
|
||||
|
||||
# Suppress warnings
|
||||
@@ -38,13 +34,7 @@ app = FastAPI(
|
||||
)
|
||||
|
||||
# Add CORS middleware
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"], # Allows all origins
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"], # Allows all methods
|
||||
allow_headers=["*"], # Allows all headers
|
||||
)
|
||||
add_middlewares(app)
|
||||
|
||||
# Global pipeline instance
|
||||
pipeline = SalaryAnalyticsPipeline()
|
||||
@@ -57,19 +47,7 @@ salary_earner_analyzer = None
|
||||
|
||||
salary_detect = SalaryDetect()
|
||||
|
||||
class AnalysisResponse(BaseModel):
    """Response model for analysis endpoints."""

    # Required: has no default, so it must always be supplied.
    message: str

    # Optional dictionary payload; defaults to None when not supplied.
    data: Optional[Dict] = None

    # Optional string path; defaults to None when not supplied.
    file_path: Optional[str] = None
|
||||
|
||||
class BatchResponse(BaseModel):
    """Response model for batch processing."""

    # Every field below is required (none declares a default).

    # Integer batch counters.
    batch_number: int
    total_batches: int

    # Integer row count reported for the batch.
    processed_rows: int

    # String fields: results location and a status/message text.
    results_path: str
    message: str
|
||||
|
||||
def check_data_loaded():
|
||||
"""Check if data is loaded before running analytics."""
|
||||
@@ -103,6 +81,8 @@ async def startup_event():
|
||||
logger.error(f"Error during startup: {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
|
||||
@app.get("/")
|
||||
async def root():
|
||||
"""Root endpoint."""
|
||||
@@ -112,6 +92,8 @@ async def root():
|
||||
logger.info(f"Root endpoint completed in {time.time() - start_time:.2f} seconds")
|
||||
return response
|
||||
|
||||
|
||||
|
||||
@app.get("/health")
|
||||
async def health_check():
|
||||
"""Health check endpoint."""
|
||||
@@ -121,6 +103,8 @@ async def health_check():
|
||||
logger.info(f"Health check completed in {time.time() - start_time:.2f} seconds")
|
||||
return response
|
||||
|
||||
|
||||
|
||||
@app.post("/analyze/keyword", response_model=AnalysisResponse)
|
||||
async def analyze_keyword():
|
||||
"""Run keyword-based salary transaction analysis."""
|
||||
@@ -141,6 +125,8 @@ async def analyze_keyword():
|
||||
logger.info(f"Keyword analysis endpoint failed after {time.time() - start_time:.2f} seconds")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
|
||||
@app.post("/analyze/consistent-amount", response_model=AnalysisResponse)
|
||||
async def analyze_consistent_amount():
|
||||
"""Run consistent amount transaction analysis."""
|
||||
@@ -161,6 +147,8 @@ async def analyze_consistent_amount():
|
||||
logger.info(f"Consistent amount analysis endpoint failed after {time.time() - start_time:.2f} seconds")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
|
||||
@app.post("/analyze/transaction-type", response_model=AnalysisResponse)
|
||||
async def analyze_transaction_type():
|
||||
"""Run transaction type analysis."""
|
||||
@@ -181,6 +169,8 @@ async def analyze_transaction_type():
|
||||
logger.info(f"Transaction type analysis endpoint failed after {time.time() - start_time:.2f} seconds")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
|
||||
@app.post("/generate/reports", response_model=AnalysisResponse)
|
||||
async def generate_reports(background_tasks: BackgroundTasks):
|
||||
"""Generate salary earner reports."""
|
||||
@@ -205,6 +195,8 @@ async def generate_reports(background_tasks: BackgroundTasks):
|
||||
logger.info(f"Report generation endpoint failed after {time.time() - start_time:.2f} seconds")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
|
||||
@app.post("/train/models", response_model=AnalysisResponse)
|
||||
async def train_models():
|
||||
"""Train salary prediction models."""
|
||||
@@ -224,6 +216,8 @@ async def train_models():
|
||||
logger.info(f"Model training endpoint failed after {time.time() - start_time:.2f} seconds")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
|
||||
@app.get("/download/{report_type}")
|
||||
async def download_report(report_type: str):
|
||||
"""Download generated reports."""
|
||||
@@ -264,6 +258,8 @@ async def download_report(report_type: str):
|
||||
logger.info(f"Download endpoint failed after {time.time() - start_time:.2f} seconds")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
|
||||
@app.post("/run/pipeline", response_model=AnalysisResponse)
|
||||
async def run_full_pipeline():
|
||||
"""Run the complete salary analytics pipeline."""
|
||||
@@ -288,6 +284,8 @@ async def run_full_pipeline():
|
||||
logger.info(f"Full pipeline endpoint failed after {time.time() - start_time:.2f} seconds")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
|
||||
@app.post("/load-data")
|
||||
async def load_data(source: str = "db", file: Optional[UploadFile] = File(None)):
|
||||
"""
|
||||
@@ -351,6 +349,8 @@ async def get_file_if_csv(source: str, file: Optional[UploadFile] = File(None)):
|
||||
raise HTTPException(status_code=400, detail="File must be provided when loading from CSV")
|
||||
return file
|
||||
|
||||
|
||||
|
||||
@app.post("/run/streaming-pipeline", response_model=List[BatchResponse])
|
||||
async def run_streaming_pipeline(
|
||||
source: str = "db",
|
||||
@@ -0,0 +1,17 @@
|
||||
from typing import Optional, Dict, List, Union
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class AnalysisResponse(BaseModel):
    """Response model for analysis endpoints."""

    # Required: has no default, so it must always be supplied.
    message: str

    # Optional dictionary payload; defaults to None when not supplied.
    data: Optional[Dict] = None

    # Optional string path; defaults to None when not supplied.
    file_path: Optional[str] = None
|
||||
|
||||
class BatchResponse(BaseModel):
    """Response model for batch processing."""

    # Every field below is required (none declares a default).

    # Integer batch counters.
    batch_number: int
    total_batches: int

    # Integer row count reported for the batch.
    processed_rows: int

    # String fields: results location and a status/message text.
    results_path: str
    message: str
|
||||
@@ -0,0 +1,11 @@
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi import FastAPI
|
||||
|
||||
def add_middlewares(app: FastAPI, allow_origins=None):
    """Register the application's HTTP middlewares on *app*.

    At present this installs only ``CORSMiddleware``.

    Args:
        app: The FastAPI application to configure; it is modified in place.
        allow_origins: Optional iterable of origin strings allowed to make
            cross-origin requests. When omitted (``None``) every origin is
            allowed (``["*"]``), which is the behaviour this function had
            before the parameter existed.

    Returns:
        None.
    """
    # NOTE(review): a wildcard origin combined with allow_credentials=True is
    # discouraged by the FastAPI CORS documentation; deployments that rely on
    # cookies/Authorization headers should pass an explicit origin list here.
    origins = ["*"] if allow_origins is None else list(allow_origins)

    app.add_middleware(
        CORSMiddleware,
        allow_origins=origins,
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
    )
|
||||
+2
-1
@@ -22,9 +22,10 @@ os.makedirs(PLOTS_DIR, exist_ok=True)
|
||||
os.makedirs(CSV_DIR, exist_ok=True)
|
||||
os.makedirs(MODEL_DIR, exist_ok=True)
|
||||
|
||||
|
||||
# Database Configuration
|
||||
DB_CONFIG = {
|
||||
"user": os.getenv("DB_USER"), # Default value as fallback
|
||||
"user": os.getenv("DB_USER"),
|
||||
"password": os.getenv("DB_PASSWORD"),
|
||||
"name": os.getenv("DB_NAME"),
|
||||
"port": os.getenv("DB_PORT"),
|
||||
|
||||
@@ -0,0 +1,97 @@
|
||||
from sqlalchemy import Column, Integer, String, DateTime, Numeric, Boolean, func
|
||||
from sqlalchemy.orm import declarative_base, Session
|
||||
from datetime import datetime
|
||||
from app.utils.logger import logger
|
||||
from app.extensions import db
|
||||
|
||||
|
||||
class BatchResult(db.Model):
    """ORM model for rows of the ``salary_analytics_batch_results`` table.

    Each row stores the batch it belongs to (``batch_number`` out of
    ``total_batches``), the processing timestamp, the account identifier,
    the computed salary figures and a ``status`` string that defaults to
    ``"success"``.
    """

    __tablename__ = "salary_analytics_batch_results"

    id = Column(Integer, primary_key=True, autoincrement=True)
    batch_number = Column(Integer, nullable=False)
    total_batches = Column(Integer, nullable=False)
    processed_at = Column(DateTime, default=datetime.utcnow)
    accountid = Column(String, nullable=False)
    num_months = Column(Integer)
    least_inflow_6m = Column(Numeric)
    avg_monthly_salary = Column(Numeric)
    estimated_next_amount = Column(Numeric)
    estimated_next_date = Column(DateTime)
    is_45day_salary = Column(Boolean, default=False)
    is_2months_salary = Column(Boolean, default=False)
    status = Column(String, default="success")

    @classmethod
    def save_batch(cls, session: Session, batch_number, total_batches, results_df, status="success"):
        """Save batch results into DB using ORM bulk insert.

        Args:
            session: SQLAlchemy session used for the insert; it is committed
                on success and rolled back on failure.
            batch_number: Number of the batch being stored.
            total_batches: Total number of batches in the run.
            results_df: DataFrame with one row per record to store. It is
                copied first, so the caller's frame is left untouched.
            status: Status string written to every row of the batch.

        Returns:
            True when the rows were committed, False when any error occurred
            (the error is logged and the session rolled back).
        """
        try:
            # Work on a copy so the caller's DataFrame is not mutated.
            frame = results_df.copy()
            frame["batch_number"] = batch_number
            frame["total_batches"] = total_batches
            frame["processed_at"] = datetime.utcnow()
            frame["status"] = status

            # Normalize boolean columns: copy the source flags (or False when
            # the source column is absent) onto the mapped attribute names.
            frame["is_45day_salary"] = frame.get("45daysalary", False)
            frame["is_2months_salary"] = frame.get("2monthssalary", False)

            # The model constructor rejects keyword arguments that are not
            # mapped attributes, so drop every column the table does not have
            # (including the raw "45daysalary" / "2monthssalary" inputs).
            mapped_names = set(cls.__table__.columns.keys())
            frame = frame[[name for name in frame.columns if name in mapped_names]]

            # Replace missing values (NaN/NaT) with None so they are stored
            # as NULL; the object cast is needed for None to survive.
            frame = frame.astype(object).where(frame.notna(), None)

            # Convert to list of ORM objects
            records = [cls(**row) for row in frame.to_dict("records")]

            session.bulk_save_objects(records)
            session.commit()
            logger.info(f"Saved batch {batch_number} successfully.")
            return True
        except Exception as e:
            session.rollback()
            logger.error(f"Error saving batch {batch_number}: {str(e)}")
            return False

    @classmethod
    def _summary_query(cls, session: Session, batch_number=None):
        """Build the grouped per-batch summary query.

        Args:
            session: SQLAlchemy session to query with.
            batch_number: When given, restrict the summary to that batch.

        Returns:
            A query yielding batch_number, total_batches, processed_at,
            total_records, successful_records and failed_records, grouped by
            the first three columns. Ordering is left to the caller.
        """
        # Imported locally because the module-level sqlalchemy import does
        # not include `case`. `func.case` would render a plain SQL function
        # call named "case" instead of a CASE expression.
        from sqlalchemy import case

        query = session.query(
            cls.batch_number,
            cls.total_batches,
            cls.processed_at,
            func.count().label("total_records"),
            func.sum(case((cls.status == "success", 1), else_=0)).label("successful_records"),
            func.sum(case((cls.status == "error", 1), else_=0)).label("failed_records"),
        )
        if batch_number is not None:
            query = query.filter(cls.batch_number == batch_number)
        return query.group_by(cls.batch_number, cls.total_batches, cls.processed_at)

    @classmethod
    def get_batch_status(cls, session: Session, batch_number: int):
        """Return summary info about one batch.

        Args:
            session: SQLAlchemy session to query with.
            batch_number: Batch to summarise.

        Returns:
            A dict with the summary of the most recently processed group for
            that batch, or None when there is no such batch or the query
            fails (the failure is logged).
        """
        try:
            result = (
                cls._summary_query(session, batch_number)
                .order_by(cls.processed_at.desc())
                .first()
            )
            return dict(result._mapping) if result else None
        except Exception as e:
            logger.error(f"Error fetching batch {batch_number} status: {str(e)}")
            return None

    @classmethod
    def get_all_batches(cls, session: Session):
        """Return summaries for all batches.

        Args:
            session: SQLAlchemy session to query with.

        Returns:
            A list of summary dicts ordered by batch number; an empty list
            when there are no rows or the query fails (the failure is logged).
        """
        try:
            results = cls._summary_query(session).order_by(cls.batch_number).all()
            return [dict(r._mapping) for r in results]
        except Exception as e:
            logger.error(f"Error fetching all batches: {str(e)}")
            return []
|
||||
@@ -3,7 +3,7 @@ Database operations module for salary analytics.
|
||||
"""
|
||||
|
||||
from sqlalchemy import text
|
||||
from .config import BATCH_RESULTS_TABLE
|
||||
from ..config import BATCH_RESULTS_TABLE
|
||||
from datetime import datetime
|
||||
from app.utils.logger import logger
|
||||
|
||||
Reference in New Issue
Block a user