Files
mcm-mfp/task1/02_demand_correction.py
2026-01-19 19:43:57 +08:00

142 lines
4.4 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
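Step 02: Truncation correction - estimate true demand

Input:  01_clean.xlsx
Output: 02_demand.xlsx

What it does:
1. Identifies high-demand sites whose observations are truncated by capacity
2. Corrects them to estimate the true demand μ̃: the truncation probability
   comes from a normal model, and the observed mean is then scaled by a
   simplified linear factor of that probability
3. Reports a before/after comparison

Core assumptions:
- The observed μ_i is an observation of the true demand truncated at capacity C
- True demand follows a normal distribution N(μ̃_i, σ̃_i²)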
"""
from pathlib import Path
from typing import Optional

import numpy as np
import pandas as pd
from scipy.stats import norm
# File locations: both workbooks live in the same directory as this script.
_STEP_DIR = Path(__file__).parent
INPUT_PATH = _STEP_DIR / "01_clean.xlsx"
OUTPUT_PATH = _STEP_DIR / "02_demand.xlsx"

# Model parameters (task-level knobs; adjust as needed).
C = 350  # effective capacity ceiling
P_TRUNC_THRESHOLD = 0.10  # truncation-probability threshold p_thresh
def truncation_correction(
    mu: float,
    sigma: float,
    C: float = 350,
    *,
    alpha: float = 0.4,
    p_thresh: Optional[float] = None,
) -> tuple:
    """Estimate a site's true demand when observations are capped at C.

    The truncation probability is P(D > C) for D ~ N(mu, sigma^2). When it
    reaches the threshold, the observed mean is inflated by the simplified
    linear rule ``mu * (1 + alpha * p_trunc)``. This is a heuristic; it is
    not the exact truncated-normal (Mills ratio) expectation
    ``mu + sigma * phi(z) / (1 - Phi(z))``.

    Args:
        mu: Observed mean.
        sigma: Observed standard deviation.
        C: Effective capacity ceiling.
        alpha: Slope of the linear correction (keyword-only).
        p_thresh: Truncation-probability threshold (keyword-only). ``None``
            means "use the module-level ``P_TRUNC_THRESHOLD``".

    Returns:
        ``(mu_tilde, p_trunc, is_corrected)`` where ``mu_tilde`` is the
        demand estimate, ``p_trunc`` the truncation probability and
        ``is_corrected`` tells whether the correction was applied.
    """
    # No correction is possible without a usable mean and a positive spread.
    # NaN must be tested explicitly: every comparison with NaN is False, so
    # a NaN mean would otherwise fall through and be reported as corrected.
    if np.isnan(mu) or np.isnan(sigma) or sigma <= 0:
        return mu, 0.0, False

    # P(D > C). The survival function keeps precision in the far upper
    # tail, where 1 - cdf(z) rounds to exactly 0.
    z = (C - mu) / sigma
    p_trunc = float(norm.sf(z))

    if p_thresh is None:
        # Looked up at call time so changes to the module constant apply.
        p_thresh = P_TRUNC_THRESHOLD

    if p_trunc < p_thresh:
        # Truncation is unlikely: treat the observation as unbiased.
        return mu, p_trunc, False

    mu_tilde = mu * (1 + alpha * p_trunc)
    return mu_tilde, p_trunc, True
def main():
    """Run step 02: correct capacity-truncated demand and save the result.

    Reads INPUT_PATH, applies truncation_correction to each row's
    ``mu``/``sigma`` with capacity ``C``, prints a report of the corrected
    sites and the before/after totals, and writes the selected columns to
    OUTPUT_PATH.

    Returns:
        The loaded DataFrame with the added columns ``mu_tilde``,
        ``p_trunc`` and ``is_corrected``.
    """
    banner = "=" * 60
    print(banner)
    print("Step 02: 截断回归修正 - 估计真实需求")
    print(banner)

    # 1. Load the cleaned data.
    print(f"\n[1] 读取输入: {INPUT_PATH}")
    df = pd.read_excel(INPUT_PATH)
    print(f" 读取 {len(df)} 条记录")

    # 2. Show the parameters in use.
    print("\n[2] 模型参数:")
    print(f" 有效容量上限 C = {C}")
    print(f" 截断概率阈值 = {P_TRUNC_THRESHOLD}")

    # 3. Apply the correction row by row. Only two columns are needed, so
    # iterate over them directly instead of materialising every row.
    print("\n[3] 应用截断修正...")
    corrections = pd.DataFrame(
        [
            truncation_correction(mu, sigma, C)
            for mu, sigma in zip(df['mu'], df['sigma'])
        ],
        columns=['mu_tilde', 'p_trunc', 'is_corrected'],
    )
    # Assign by position: the results are in row order, so they must not be
    # re-aligned on index labels.
    for column in ('mu_tilde', 'p_trunc', 'is_corrected'):
        df[column] = corrections[column].to_numpy()

    # 4. Correction statistics.
    n_corrected = int(df['is_corrected'].sum())
    print("\n[4] 修正统计:")
    print(f" 被修正的站点数: {n_corrected} / {len(df)}")
    if n_corrected > 0:
        corrected_sites = df[df['is_corrected']]
        print("\n 被修正站点详情:")
        print(f" {'site_id':<8} {'site_name':<35} {'μ':>8} {'σ':>8} {'p_trunc':>8} {'μ̃':>10} {'修正幅度':>10}")
        print(f" {'-'*8} {'-'*35} {'-'*8} {'-'*8} {'-'*8} {'-'*10} {'-'*10}")
        for _, row in corrected_sites.iterrows():
            mu = row['mu']
            # A zero mean has no meaningful relative change; report NaN
            # instead of dividing by zero.
            correction_pct = (row['mu_tilde'] - mu) / mu * 100 if mu else float('nan')
            # Names may be blank (NaN) or numeric in the workbook; coerce to
            # text before slicing.
            site_name = str(row['site_name'])[:35]
            print(f" {row['site_id']:<8} {site_name:<35} {row['mu']:>8.1f} {row['sigma']:>8.1f} {row['p_trunc']:>8.3f} {row['mu_tilde']:>10.1f} {correction_pct:>9.1f}%")

    # 5. Before/after comparison.
    total_before = df['mu'].sum()
    total_after = df['mu_tilde'].sum()
    # An all-zero total has no meaningful relative increase.
    increase_pct = (total_after / total_before - 1) * 100 if total_before else float('nan')
    print("\n[5] 修正前后对比:")
    print(f" 修正前 μ 总和: {total_before:.1f}")
    print(f" 修正后 μ̃ 总和: {total_after:.1f}")
    print(f" 增幅: {increase_pct:.2f}%")

    # 6. Save the selected output columns.
    print(f"\n[6] 保存输出: {OUTPUT_PATH}")
    output_cols = ['site_id', 'site_name', 'lat', 'lon', 'visits_2019',
                   'mu', 'sigma', 'mu_tilde', 'p_trunc', 'is_corrected']
    df[output_cols].to_excel(OUTPUT_PATH, index=False)
    print(f" 已保存 {len(df)} 条记录")

    # 7. Preview the ten sites with the highest observed mean.
    print("\n[7] 输出数据预览 (μ 最高的10个站点):")
    top10 = df.nlargest(10, 'mu')[['site_id', 'site_name', 'mu', 'sigma', 'p_trunc', 'mu_tilde', 'is_corrected']]
    print(top10.to_string(index=False))

    print("\n" + banner)
    print("Step 02 完成")
    print(banner)
    return df
# Run the step only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()