P1: rebuild
This commit is contained in:
140
task1/02_demand_correction.py
Normal file
140
task1/02_demand_correction.py
Normal file
@@ -0,0 +1,140 @@
|
||||
"""
|
||||
Step 02: 截断回归修正 - 估计真实需求
|
||||
|
||||
输入: 01_clean.xlsx
|
||||
输出: 02_demand.xlsx
|
||||
|
||||
功能:
|
||||
1. 识别被容量截断的高需求站点
|
||||
2. 使用截断正态模型修正,估计真实需求 μ̃
|
||||
3. 输出修正前后的对比
|
||||
|
||||
核心假设:
|
||||
- 观测到的 μ_i 是真实需求在容量 C 下的截断观测
|
||||
- 真实需求服从正态分布 N(μ̃_i, σ̃_i²)
|
||||
"""
|
||||
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
from scipy.stats import norm
|
||||
from pathlib import Path
|
||||
|
||||
# 路径配置
|
||||
INPUT_PATH = Path(__file__).parent / "01_clean.xlsx"
|
||||
OUTPUT_PATH = Path(__file__).parent / "02_demand.xlsx"
|
||||
|
||||
# 模型参数
|
||||
C = 400 # 有效容量上限 (基于 μ_max = 396.6)
|
||||
P_TRUNC_THRESHOLD = 0.02 # 截断概率阈值 (调低以捕获更多潜在截断站点)
|
||||
|
||||
|
||||
def truncation_correction(mu: float, sigma: float, C: float = 400) -> tuple:
|
||||
"""
|
||||
截断回归修正
|
||||
|
||||
Args:
|
||||
mu: 观测均值
|
||||
sigma: 观测标准差
|
||||
C: 有效容量上限
|
||||
|
||||
Returns:
|
||||
(mu_tilde, p_trunc, is_corrected)
|
||||
- mu_tilde: 修正后的真实需求估计
|
||||
- p_trunc: 截断概率
|
||||
- is_corrected: 是否进行了修正
|
||||
"""
|
||||
if sigma <= 0 or np.isnan(sigma):
|
||||
return mu, 0.0, False
|
||||
|
||||
# 计算截断概率: P(D > C)
|
||||
z = (C - mu) / sigma
|
||||
p_trunc = 1 - norm.cdf(z)
|
||||
|
||||
if p_trunc < P_TRUNC_THRESHOLD:
|
||||
# 截断概率低于阈值,视为未截断
|
||||
return mu, p_trunc, False
|
||||
else:
|
||||
# 截断修正: 使用 Mills ratio 近似
|
||||
# E[D | D > C] = μ + σ * φ(z) / (1 - Φ(z))
|
||||
# 修正后: μ̃ ≈ μ * (1 + α * p_trunc)
|
||||
# 这里使用简化的线性修正
|
||||
correction_factor = 1 + 0.4 * p_trunc
|
||||
mu_tilde = mu * correction_factor
|
||||
return mu_tilde, p_trunc, True
|
||||
|
||||
|
||||
def main():
|
||||
print("=" * 60)
|
||||
print("Step 02: 截断回归修正 - 估计真实需求")
|
||||
print("=" * 60)
|
||||
|
||||
# 1. 读取清洗后的数据
|
||||
print(f"\n[1] 读取输入: {INPUT_PATH}")
|
||||
df = pd.read_excel(INPUT_PATH)
|
||||
print(f" 读取 {len(df)} 条记录")
|
||||
|
||||
# 2. 显示参数
|
||||
print(f"\n[2] 模型参数:")
|
||||
print(f" 有效容量上限 C = {C}")
|
||||
print(f" 截断概率阈值 = {P_TRUNC_THRESHOLD}")
|
||||
|
||||
# 3. 应用截断修正
|
||||
print(f"\n[3] 应用截断修正...")
|
||||
results = []
|
||||
for _, row in df.iterrows():
|
||||
mu_tilde, p_trunc, is_corrected = truncation_correction(
|
||||
row['mu'], row['sigma'], C
|
||||
)
|
||||
results.append({
|
||||
'mu_tilde': mu_tilde,
|
||||
'p_trunc': p_trunc,
|
||||
'is_corrected': is_corrected
|
||||
})
|
||||
|
||||
df_result = pd.DataFrame(results)
|
||||
df['mu_tilde'] = df_result['mu_tilde']
|
||||
df['p_trunc'] = df_result['p_trunc']
|
||||
df['is_corrected'] = df_result['is_corrected']
|
||||
|
||||
# 4. 修正统计
|
||||
n_corrected = df['is_corrected'].sum()
|
||||
print(f"\n[4] 修正统计:")
|
||||
print(f" 被修正的站点数: {n_corrected} / {len(df)}")
|
||||
|
||||
if n_corrected > 0:
|
||||
corrected_sites = df[df['is_corrected']]
|
||||
print(f"\n 被修正站点详情:")
|
||||
print(f" {'site_id':<8} {'site_name':<35} {'μ':>8} {'σ':>8} {'p_trunc':>8} {'μ̃':>10} {'修正幅度':>10}")
|
||||
print(f" {'-'*8} {'-'*35} {'-'*8} {'-'*8} {'-'*8} {'-'*10} {'-'*10}")
|
||||
for _, row in corrected_sites.iterrows():
|
||||
correction_pct = (row['mu_tilde'] - row['mu']) / row['mu'] * 100
|
||||
print(f" {row['site_id']:<8} {row['site_name'][:35]:<35} {row['mu']:>8.1f} {row['sigma']:>8.1f} {row['p_trunc']:>8.3f} {row['mu_tilde']:>10.1f} {correction_pct:>9.1f}%")
|
||||
|
||||
# 5. 修正前后对比
|
||||
print(f"\n[5] 修正前后对比:")
|
||||
print(f" 修正前 μ 总和: {df['mu'].sum():.1f}")
|
||||
print(f" 修正后 μ̃ 总和: {df['mu_tilde'].sum():.1f}")
|
||||
print(f" 增幅: {(df['mu_tilde'].sum() / df['mu'].sum() - 1) * 100:.2f}%")
|
||||
|
||||
# 6. 保存输出
|
||||
print(f"\n[6] 保存输出: {OUTPUT_PATH}")
|
||||
# 选择输出列
|
||||
output_cols = ['site_id', 'site_name', 'lat', 'lon', 'visits_2019',
|
||||
'mu', 'sigma', 'mu_tilde', 'p_trunc', 'is_corrected']
|
||||
df[output_cols].to_excel(OUTPUT_PATH, index=False)
|
||||
print(f" 已保存 {len(df)} 条记录")
|
||||
|
||||
# 7. 输出预览
|
||||
print(f"\n[7] 输出数据预览 (μ 最高的10个站点):")
|
||||
top10 = df.nlargest(10, 'mu')[['site_id', 'site_name', 'mu', 'sigma', 'p_trunc', 'mu_tilde', 'is_corrected']]
|
||||
print(top10.to_string(index=False))
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print("Step 02 完成")
|
||||
print("=" * 60)
|
||||
|
||||
return df
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user