import pandas as pd
import matplotlib.pyplot as plt
from db import get_db_connection
import matplotlib.dates as mdates
import math


# Step 1: Load per-user, per-status, per-day activity counts from the database
conn = get_db_connection()
q = "SELECT username, estatus, edate, COUNT(*) AS activity_count FROM rssj.estatus GROUP BY username, estatus, edate ORDER BY username, estatus, edate;"
try:
    df = pd.read_sql_query(q, conn)
finally:
    # Close the connection even when the query fails, so an error here
    # doesn't leak it.
    conn.close()


# Step 2: Fill Missing Dates per User/Estatus
df['edate'] = pd.to_datetime(df['edate'])
date_range = pd.date_range(start='2025-01-01', end=pd.Timestamp.today(), freq='D')
users = df['username'].unique()
statuses = df['estatus'].unique()
full_index = pd.MultiIndex.from_product([users, statuses, date_range], names=['username', 'estatus', 'edate'])
df_full = df.set_index(['username', 'estatus', 'edate']).reindex(full_index, fill_value=0).reset_index()

# Step 3: Filter Users with ≥ 10 Activities
# A user qualifies when their activity, summed over every status and every
# day in the window, reaches MIN_TOTAL_ACTIVITY. All of a qualifying user's
# rows are kept, in their original order.
MIN_TOTAL_ACTIVITY = 10
user_totals = df_full.groupby('username')['activity_count'].sum()
qualifying_users = user_totals.index[user_totals.ge(MIN_TOTAL_ACTIVITY)]
filtered_df = df_full.loc[df_full['username'].isin(qualifying_users)]

# Step 4: Plot in 3-Column Grid with Stacked Area
GRID_COLS = 3
Y_MAX = 100  # fixed upper y-limit so panels share a scale; None autoscales each panel
SHARED_LEGEND = False  # True draws one figure-level legend instead of one per panel


def plot_user_activity_grid(data, cols=GRID_COLS, y_max=Y_MAX, shared_legend=SHARED_LEGEND):
    """Draw one stacked-area panel per user showing daily activity counts by estatus.

    Args:
        data: Long-format DataFrame with columns ``username``, ``estatus``,
            ``edate`` (datetime) and ``activity_count``, one row per
            user/status/day.
        cols: Number of panels per grid row.
        y_max: Upper y-limit applied to every panel. Stacked totals above it
            are clipped from view. ``None`` keeps matplotlib's autoscaling.
        shared_legend: If True, suppress the per-panel legends and draw a
            single legend for the whole figure.

    Returns:
        The matplotlib Figure, or None when ``data`` has no users to plot.
    """
    users = data['username'].unique()
    if len(users) == 0:
        # plt.subplots(0, cols) would raise; report instead of crashing.
        print("No users to plot.")
        return None

    # Give each estatus a fixed colour up front. pandas colours area plots by
    # column position, and all-zero statuses are dropped per user below, so
    # without this the same status would change colour between panels.
    palette = plt.rcParams['axes.prop_cycle'].by_key()['color']
    status_colors = {
        status: palette[i % len(palette)]
        for i, status in enumerate(data['estatus'].unique())
    }

    rows = math.ceil(len(users) / cols)
    # squeeze=False always yields a 2-D array of Axes, so flatten() is valid
    # for any grid shape.
    fig, axes = plt.subplots(rows, cols, figsize=(6 * cols, 4 * rows), sharex=True, squeeze=False)
    axes = axes.flatten()
    plotted = axes[:len(users)]

    for i, (ax, username) in enumerate(zip(plotted, users)):
        user_data = data[data['username'] == username]
        pivot = user_data.pivot(index='edate', columns='estatus', values='activity_count').fillna(0)
        # Drop statuses this user never had so they don't appear in the panel or legend.
        pivot = pivot.loc[:, ~(pivot == 0).all()]

        # x_compat=True makes pandas plot on plain matplotlib date numbers,
        # which is what the mdates locator/formatter set below expect.
        pivot.plot.area(
            ax=ax,
            stacked=True,
            alpha=0.7,
            x_compat=True,
            color=[status_colors.get(status, 'grey') for status in pivot.columns],
            legend=not shared_legend,
        )

        ax.set_title(username)
        ax.set_xlabel('')  # replaced by one figure-level "Date" label below
        if y_max is not None:
            ax.set_ylim(0, y_max)
        ax.grid(axis='y', linestyle='--', alpha=0.3)
        ax.xaxis.set_major_locator(mdates.MonthLocator())
        ax.xaxis.set_major_formatter(mdates.DateFormatter('%b %Y'))
        ax.tick_params(axis='x', labelrotation=45)
        # sharex=True hides x tick labels on all but the bottom grid row. When
        # the last row is only partly filled, a panel with no plotted panel
        # below it is the lowest in its column and must show its own dates.
        if i + cols >= len(users):
            ax.xaxis.set_tick_params(labelbottom=True)

    # Hide the unused cells at the end of the grid. A separate loop variable
    # keeps `ax` from the plotting loop from being rebound to an empty axis.
    for spare in axes[len(users):]:
        spare.axis('off')

    # Add the label before the layout pass so tight_layout can make room for it.
    fig.supxlabel('Date')

    if shared_legend:
        # Gather handles from every panel: any single panel may lack some statuses.
        handles_by_label = {}
        for ax in plotted:
            for handle, label in zip(*ax.get_legend_handles_labels()):
                handles_by_label.setdefault(label, handle)
        # Reserve a strip at the top of the figure for the legend.
        fig.tight_layout(rect=(0, 0, 1, 0.93))
        fig.legend(
            list(handles_by_label.values()),
            list(handles_by_label.keys()),
            title='Estatus',
            loc='upper center',
            ncol=max(1, min(5, len(handles_by_label))),
        )
    else:
        fig.tight_layout()

    plt.show()
    return fig


plot_user_activity_grid(filtered_df)