# (Pasted file-viewer header: 101 lines, 3.3 KiB, Python.)
import logging
# NOTE(review): this is the stdlib `venv` package's internal logger, most likely
# an accidental IDE auto-import; confirm nothing else relies on it before removing.
from venv import logger

import pandas as pd
from sqlalchemy import Column, DateTime, Float, ForeignKey, Integer, String
from sqlalchemy.orm import relationship

from app.extensions import db


class CustomerAccountTransactionHx(db.Model):
    """ORM model for the ``customer_account_transaction_hx`` table.

    Each row is one transaction of an account. Besides the mapped columns the
    class provides small query/insert helpers and a pandas DataFrame export
    used for ML model preparation.
    """

    __tablename__ = "customer_account_transaction_hx"

    # Application logger for this module. The file-level ``logger`` import
    # comes from the stdlib ``venv`` package and is not an application logger.
    _log = logging.getLogger(__name__)

    id = Column(Integer, primary_key=True, autoincrement=True)
    accountid = Column(String(64), nullable=False, index=True)
    trx_type = Column(String(50), nullable=False)
    # NOTE(review): Float is binary floating point; monetary amounts are usually
    # stored as Numeric/Decimal to avoid rounding error. Changing the type needs
    # a schema migration, so it is left unchanged here.
    amount = Column(Float, nullable=False)
    description = Column(String(255))
    customer_id = Column(String(64))
    trx_start_date = Column(DateTime, nullable=False)
    trx_end_date = Column(DateTime)
    # Integer flags defaulting to 0 (presumably 0/1 booleans — confirm).
    is_salary_related = Column(Integer, default=0)
    is_consistent_amount = Column(Integer, default=0)
    is_salary_type = Column(Integer, default=0)

    # Columns (and their order) of the DataFrame built by get_transactions_df().
    _DF_COLUMNS = (
        "id",
        "accountid",
        "trx_type",
        "amount",
        "description",
        "customer_id",
        "trx_start_date",
        "trx_end_date",
        "is_salary_related",
        "is_consistent_amount",
        "is_salary_type",
    )

    @classmethod
    def get_all(cls):
        """Return every transaction row as a list of model instances."""
        return db.session.query(cls).all()

    @classmethod
    def get_rows_count(cls):
        """Return the total number of transaction rows.

        Returns:
            The row count, or ``None`` if the query fails. The error is logged
            and the session is rolled back.
        """
        try:
            return db.session.query(db.func.count(cls.id)).scalar()
        except Exception as e:
            cls._log.exception("Error getting row count: %s", e)
            # A failed statement can leave the session's transaction unusable;
            # roll back so later work on the same session can proceed.
            db.session.rollback()
            return None

    @classmethod
    def get_by_account(cls, accountid: str):
        """Return all transactions whose ``accountid`` equals *accountid*."""
        return db.session.query(cls).filter_by(accountid=accountid).all()

    @classmethod
    def get_accounts(cls, limit=None):
        """Return distinct account IDs as a list of strings.

        Args:
            limit: Maximum number of IDs to return. Any falsy value (``None``
                or ``0``) means "no limit".

        When a limit is applied the IDs are ordered by ``accountid`` so that
        repeated calls return the same subset (LIMIT without ORDER BY lets the
        database pick an arbitrary subset).
        """
        query = db.session.query(cls.accountid).distinct()
        if limit:
            query = query.order_by(cls.accountid).limit(limit)
        return [row.accountid for row in query.all()]

    @classmethod
    def insert_transaction(cls, **kwargs):
        """Create, add and commit a single transaction.

        Args:
            **kwargs: Column values passed to the model constructor.

        Returns:
            The committed instance, or ``None`` if adding/committing failed.
            The error is logged and the session is rolled back.
        """
        trx = cls(**kwargs)
        try:
            db.session.add(trx)
            db.session.commit()
        except Exception as e:
            cls._log.exception("Error inserting transaction: %s", e)
            # Without a rollback the session can stay in a failed state and the
            # pending ``trx`` remains attached; rollback clears both.
            db.session.rollback()
            return None
        return trx

    @classmethod
    def bulk_insert(cls, transactions: list[dict]):
        """Insert many transactions in one ``bulk_save_objects`` call and commit.

        Args:
            transactions: One dict of column values per transaction.

        Returns:
            The list of constructed instances, or ``None`` if saving/committing
            failed. The error is logged and the session is rolled back.
            NOTE(review): ``bulk_save_objects`` does not fetch generated primary
            keys by default, so ``id`` on the returned objects may be ``None``.
        """
        objs = [cls(**trx) for trx in transactions]
        try:
            db.session.bulk_save_objects(objs)
            db.session.commit()
        except Exception as e:
            cls._log.exception("Error in bulk insert: %s", e)
            # Restore the session to a usable state after the failure.
            db.session.rollback()
            return None
        return objs

    @classmethod
    def get_transactions_df(cls, accountids: list[str] = None):
        """Return transactions as a pandas DataFrame for ML model preparation.

        Args:
            accountids: Restrict the result to these account IDs. ``None`` or an
                empty list returns the rows of all accounts.

        Returns:
            A DataFrame with one row per transaction and the columns listed in
            ``_DF_COLUMNS`` (in that order). The columns are present even when
            no rows match, so column access does not fail on an empty result.
        """
        selected = [getattr(cls, name) for name in cls._DF_COLUMNS]
        # Selecting plain columns avoids building a full ORM object per row.
        query = db.session.query(*selected)
        if accountids:
            query = query.filter(cls.accountid.in_(accountids))
        rows = [tuple(row) for row in query.all()]
        return pd.DataFrame(rows, columns=list(cls._DF_COLUMNS))