P1: visual

This commit is contained in:
2026-01-17 23:03:26 +08:00
parent 5fa64a8720
commit 6e3072155b
4 changed files with 234 additions and 0 deletions

View File

@@ -21,3 +21,9 @@ Optimize a 365-day schedule with at most 2 visits per day and minimum gap constr
- `python3 scheduling_optimization.py --days 365 --daily-capacity 2 --gap-min 14`
- Outputs are written to `data/` (e.g., `data/schedule_optimized_kmin6.8_gap14.csv`), using `data/kmin_effectiveness_data.csv` as the frequency source.
### Visualization (Plan A)
- `python3 visualize_schedule.py`
- Outputs: `data/schedule_barcode_*.png` and `data/schedule_gap_deviation_*.png`
- Site label rule: remove first 4 chars, then take 12 chars.

Binary file not shown.

After

Width:  |  Height:  |  Size: 75 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 62 KiB

228
visualize_schedule.py Normal file
View File

@@ -0,0 +1,228 @@
"""
Schedule Visualization (Plan A)
Produces:
1) Barcode/Raster plot: site vs day visits
2) Gap deviation plot: (gap - ideal_gap) grouped by frequency
Inputs:
- data/schedule_long_*.csv from scheduling_optimization.py
- data/kmin_effectiveness_sites.csv (site metadata)
Site short name rule (per user request):
- remove first 4 characters, then take next 12 characters.
"""
from __future__ import annotations
import argparse
import glob
import os
from typing import Dict, List, Optional, Tuple
import numpy as np
import pandas as pd
try:
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
_HAS_MPL = True
except ModuleNotFoundError:
plt = None
_HAS_MPL = False
OUTPUT_DIR = "data"
DEFAULT_SITES_CSV = os.path.join(OUTPUT_DIR, "kmin_effectiveness_sites.csv")
def short_site_name(name: str) -> str:
s = (name or "").strip()
if len(s) <= 4:
return s[:12]
return s[4:][:12]
def find_latest_file(pattern: str) -> str:
matches = glob.glob(pattern)
if not matches:
raise FileNotFoundError(f"No files match: {pattern}")
matches.sort(key=lambda p: os.path.getmtime(p), reverse=True)
return matches[0]
def stem_from_filename(path: str) -> str:
base = os.path.basename(path)
for prefix in ("schedule_long_", "schedule_optimized_", "site_visits_"):
if base.startswith(prefix) and base.endswith(".csv"):
return base[len(prefix) : -len(".csv")]
if base.endswith(".csv"):
return base[:-len(".csv")]
return base
def load_schedule_long(path: str) -> pd.DataFrame:
df = pd.read_csv(path)
if "day" not in df.columns or "site_idx" not in df.columns:
raise ValueError(f"Expected columns day, site_idx in {path}")
df["day"] = df["day"].astype(int)
df["site_idx"] = df["site_idx"].astype(int)
return df
def load_sites(path: str) -> pd.DataFrame:
df = pd.read_csv(path)
needed = {"site_idx", "site_name"}
if not needed.issubset(df.columns):
raise ValueError(f"Expected columns {sorted(needed)} in {path}")
df = df.copy()
df["site_idx"] = df["site_idx"].astype(int)
df["site_name"] = df["site_name"].astype(str)
if "total_demand" in df.columns:
df["total_demand"] = pd.to_numeric(df["total_demand"], errors="coerce")
return df
def compute_gaps(schedule_long: pd.DataFrame) -> pd.DataFrame:
gaps_rows: List[Dict[str, float]] = []
for site_idx, g in schedule_long.groupby("site_idx"):
days = sorted(g["day"].tolist())
if len(days) < 2:
continue
for a, b in zip(days, days[1:]):
gaps_rows.append({"site_idx": int(site_idx), "gap": int(b - a)})
return pd.DataFrame(gaps_rows)
def plot_barcode(
schedule_long: pd.DataFrame,
sites: pd.DataFrame,
*,
days: int,
sort_by: str,
out_path: str,
) -> None:
if not _HAS_MPL:
raise RuntimeError("Missing dependency: matplotlib (cannot plot).")
sites2 = sites.copy()
sites2["short_name"] = sites2["site_name"].map(short_site_name)
if sort_by == "site_idx":
sites2 = sites2.sort_values(["site_idx"])
elif sort_by == "total_demand":
if "total_demand" not in sites2.columns:
raise ValueError("sites CSV missing total_demand; cannot sort by total_demand")
sites2 = sites2.sort_values(["total_demand", "site_idx"], ascending=[False, True])
else:
raise ValueError("sort_by must be 'site_idx' or 'total_demand'")
order = sites2["site_idx"].tolist()
y_pos = {idx: i for i, idx in enumerate(order)}
y = schedule_long["site_idx"].map(y_pos).to_numpy()
x = schedule_long["day"].to_numpy()
fig, ax = plt.subplots(figsize=(14, 8))
ax.scatter(x, y, s=18, marker="|", linewidths=1.5, alpha=0.7, color="black")
ax.set_xlim(1, days)
ax.set_ylim(-1, len(order))
ax.set_xlabel("Day (1..365)")
ax.set_ylabel("Sites (sorted)")
ax.set_title("Schedule Barcode (Visits over 365 days)")
ax.grid(True, axis="x", alpha=0.15)
# Show a small subset of y tick labels for readability.
step = max(1, len(order) // 12)
tick_idx = list(range(0, len(order), step))
tick_labels = sites2["short_name"].tolist()
ax.set_yticks(tick_idx)
ax.set_yticklabels([tick_labels[i] for i in tick_idx], fontsize=9)
fig.tight_layout()
os.makedirs(os.path.dirname(out_path) or ".", exist_ok=True)
fig.savefig(out_path, dpi=160)
plt.close(fig)
def plot_gap_deviation(
schedule_long: pd.DataFrame,
sites: pd.DataFrame,
*,
days: int,
gap_min: int,
out_path: str,
) -> None:
if not _HAS_MPL:
raise RuntimeError("Missing dependency: matplotlib (cannot plot).")
# Infer f_i from schedule itself (more robust than requiring the frequency CSV).
freq = schedule_long.groupby("site_idx")["day"].size().rename("f_i").reset_index()
gaps = compute_gaps(schedule_long)
df = gaps.merge(freq, on="site_idx", how="left").merge(sites[["site_idx", "site_name"]], on="site_idx", how="left")
df["ideal_gap"] = df["f_i"].apply(lambda f: (days / f) if f and f > 0 else np.nan)
df["dev"] = df["gap"] - df["ideal_gap"]
# Group deviations by frequency for a boxplot.
freq_levels = sorted(df["f_i"].dropna().unique().astype(int).tolist())
data = [df.loc[df["f_i"] == f, "dev"].dropna().to_numpy() for f in freq_levels]
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5), gridspec_kw={"width_ratios": [2.2, 1.0]})
ax1.boxplot(data, labels=[str(f) for f in freq_levels], showfliers=False)
ax1.axhline(0, color="black", lw=1, alpha=0.6)
ax1.set_xlabel("Frequency f_i (visits/year)")
ax1.set_ylabel("Gap deviation (gap - 365/f_i) in days")
ax1.set_title("Gap Regularity by Frequency")
ax1.grid(True, axis="y", alpha=0.2)
# Quick diagnostics: min gap violations and deviation histogram.
violations = int((df["gap"] < gap_min).sum())
ax2.hist(df["dev"].dropna().to_numpy(), bins=20, color="tab:blue", alpha=0.85)
ax2.axvline(0, color="black", lw=1, alpha=0.6)
ax2.set_xlabel("Deviation (days)")
ax2.set_ylabel("Count")
ax2.set_title(f"Deviation Histogram\nGap_min<{gap_min} violations: {violations}")
ax2.grid(True, axis="y", alpha=0.2)
fig.tight_layout()
os.makedirs(os.path.dirname(out_path) or ".", exist_ok=True)
fig.savefig(out_path, dpi=160)
plt.close(fig)
def main() -> None:
parser = argparse.ArgumentParser(description="Visualize optimized schedule outputs.")
parser.add_argument(
"--schedule-long",
default=None,
help="Path to data/schedule_long_*.csv. If omitted, uses the latest matching file in data/.",
)
parser.add_argument("--sites-csv", default=DEFAULT_SITES_CSV)
parser.add_argument("--days", type=int, default=365)
parser.add_argument("--gap-min", type=int, default=14)
parser.add_argument("--sort-by", choices=["site_idx", "total_demand"], default="total_demand")
args = parser.parse_args()
if args.schedule_long is None:
args.schedule_long = find_latest_file(os.path.join(OUTPUT_DIR, "schedule_long_*.csv"))
schedule_long = load_schedule_long(args.schedule_long)
sites = load_sites(args.sites_csv)
stem = stem_from_filename(args.schedule_long)
out_barcode = os.path.join(OUTPUT_DIR, f"schedule_barcode_{stem}.png")
out_gaps = os.path.join(OUTPUT_DIR, f"schedule_gap_deviation_{stem}.png")
plot_barcode(schedule_long, sites, days=args.days, sort_by=args.sort_by, out_path=out_barcode)
plot_gap_deviation(schedule_long, sites, days=args.days, gap_min=args.gap_min, out_path=out_gaps)
print(f"Saved: {out_barcode}")
print(f"Saved: {out_gaps}")
if __name__ == "__main__":
main()