142 lines
4.4 KiB
Python
142 lines
4.4 KiB
Python
"""
|
||
Step 02: 截断回归修正 - 估计真实需求
|
||
|
||
输入: 01_clean.xlsx
|
||
输出: 02_demand.xlsx
|
||
|
||
功能:
|
||
1. 识别被容量截断的高需求站点
|
||
2. 使用截断正态模型修正,估计真实需求 μ̃
|
||
3. 输出修正前后的对比
|
||
|
||
核心假设:
|
||
- 观测到的 μ_i 是真实需求在容量 C 下的截断观测
|
||
- 真实需求服从正态分布 N(μ̃_i, σ̃_i²)
|
||
"""
|
||
|
||
import pandas as pd
|
||
import numpy as np
|
||
from scipy.stats import norm
|
||
from pathlib import Path
|
||
|
||
# 路径配置
|
||
INPUT_PATH = Path(__file__).parent / "01_clean.xlsx"
|
||
OUTPUT_PATH = Path(__file__).parent / "02_demand.xlsx"
|
||
|
||
# 模型参数
|
||
# 任务参数(按需可调)
|
||
C = 350 # 有效容量上限
|
||
P_TRUNC_THRESHOLD = 0.10 # 截断概率阈值 p_thresh
|
||
|
||
|
||
def truncation_correction(mu: float, sigma: float, C: float = 350) -> tuple:
|
||
"""
|
||
截断回归修正
|
||
|
||
Args:
|
||
mu: 观测均值
|
||
sigma: 观测标准差
|
||
C: 有效容量上限
|
||
|
||
Returns:
|
||
(mu_tilde, p_trunc, is_corrected)
|
||
- mu_tilde: 修正后的真实需求估计
|
||
- p_trunc: 截断概率
|
||
- is_corrected: 是否进行了修正
|
||
"""
|
||
if sigma <= 0 or np.isnan(sigma):
|
||
return mu, 0.0, False
|
||
|
||
# 计算截断概率: P(D > C)
|
||
z = (C - mu) / sigma
|
||
p_trunc = 1 - norm.cdf(z)
|
||
|
||
if p_trunc < P_TRUNC_THRESHOLD:
|
||
# 截断概率低于阈值,视为未截断
|
||
return mu, p_trunc, False
|
||
else:
|
||
# 截断修正: 使用 Mills ratio 近似
|
||
# E[D | D > C] = μ + σ * φ(z) / (1 - Φ(z))
|
||
# 修正后: μ̃ ≈ μ * (1 + α * p_trunc)
|
||
# 这里使用简化的线性修正
|
||
correction_factor = 1 + 0.1 * p_trunc
|
||
mu_tilde = mu * correction_factor
|
||
return mu_tilde, p_trunc, True
|
||
|
||
|
||
def main():
|
||
print("=" * 60)
|
||
print("Step 02: 截断回归修正 - 估计真实需求")
|
||
print("=" * 60)
|
||
|
||
# 1. 读取清洗后的数据
|
||
print(f"\n[1] 读取输入: {INPUT_PATH}")
|
||
df = pd.read_excel(INPUT_PATH)
|
||
print(f" 读取 {len(df)} 条记录")
|
||
|
||
# 2. 显示参数
|
||
print(f"\n[2] 模型参数:")
|
||
print(f" 有效容量上限 C = {C}")
|
||
print(f" 截断概率阈值 = {P_TRUNC_THRESHOLD}")
|
||
|
||
# 3. 应用截断修正
|
||
print(f"\n[3] 应用截断修正...")
|
||
results = []
|
||
for _, row in df.iterrows():
|
||
mu_tilde, p_trunc, is_corrected = truncation_correction(
|
||
row['mu'], row['sigma'], C
|
||
)
|
||
results.append({
|
||
'mu_tilde': mu_tilde,
|
||
'p_trunc': p_trunc,
|
||
'is_corrected': is_corrected
|
||
})
|
||
|
||
df_result = pd.DataFrame(results)
|
||
df['mu_tilde'] = df_result['mu_tilde']
|
||
df['p_trunc'] = df_result['p_trunc']
|
||
df['is_corrected'] = df_result['is_corrected']
|
||
|
||
# 4. 修正统计
|
||
n_corrected = df['is_corrected'].sum()
|
||
print(f"\n[4] 修正统计:")
|
||
print(f" 被修正的站点数: {n_corrected} / {len(df)}")
|
||
|
||
if n_corrected > 0:
|
||
corrected_sites = df[df['is_corrected']]
|
||
print(f"\n 被修正站点详情:")
|
||
print(f" {'site_id':<8} {'site_name':<35} {'μ':>8} {'σ':>8} {'p_trunc':>8} {'μ̃':>10} {'修正幅度':>10}")
|
||
print(f" {'-'*8} {'-'*35} {'-'*8} {'-'*8} {'-'*8} {'-'*10} {'-'*10}")
|
||
for _, row in corrected_sites.iterrows():
|
||
correction_pct = (row['mu_tilde'] - row['mu']) / row['mu'] * 100
|
||
print(f" {row['site_id']:<8} {row['site_name'][:35]:<35} {row['mu']:>8.1f} {row['sigma']:>8.1f} {row['p_trunc']:>8.3f} {row['mu_tilde']:>10.1f} {correction_pct:>9.1f}%")
|
||
|
||
# 5. 修正前后对比
|
||
print(f"\n[5] 修正前后对比:")
|
||
print(f" 修正前 μ 总和: {df['mu'].sum():.1f}")
|
||
print(f" 修正后 μ̃ 总和: {df['mu_tilde'].sum():.1f}")
|
||
print(f" 增幅: {(df['mu_tilde'].sum() / df['mu'].sum() - 1) * 100:.2f}%")
|
||
|
||
# 6. 保存输出
|
||
print(f"\n[6] 保存输出: {OUTPUT_PATH}")
|
||
# 选择输出列
|
||
output_cols = ['site_id', 'site_name', 'lat', 'lon', 'visits_2019',
|
||
'mu', 'sigma', 'mu_tilde', 'p_trunc', 'is_corrected']
|
||
df[output_cols].to_excel(OUTPUT_PATH, index=False)
|
||
print(f" 已保存 {len(df)} 条记录")
|
||
|
||
# 7. 输出预览
|
||
print(f"\n[7] 输出数据预览 (μ 最高的10个站点):")
|
||
top10 = df.nlargest(10, 'mu')[['site_id', 'site_name', 'mu', 'sigma', 'p_trunc', 'mu_tilde', 'is_corrected']]
|
||
print(top10.to_string(index=False))
|
||
|
||
print("\n" + "=" * 60)
|
||
print("Step 02 完成")
|
||
print("=" * 60)
|
||
|
||
return df
|
||
|
||
|
||
if __name__ == "__main__":
|
||
main()
|