Spaces:
Sleeping
Sleeping
from huggingface_hub import hf_hub_download | |
import streamlit as st | |
import pandas as pd | |
import joblib | |
import matplotlib.pyplot as plt | |
import seaborn as sns | |
from scipy.interpolate import make_interp_spline | |
import numpy as np | |
import os | |
# Hugging Face access token taken from the environment; None when HF_TOKEN is unset.
hf_token = os.environ.get("HF_TOKEN")
def load_model(repo_id, filename):
    """Download a model file from the Hugging Face Hub and load it with joblib.

    Args:
        repo_id: Hub repository id to download from.
        filename: Path of the model file inside that repository.

    Returns:
        The object deserialized by ``joblib.load`` from the downloaded file.
    """
    # Authenticate with the module-level ``hf_token`` (None when HF_TOKEN is
    # unset). ``token`` is the current name of the deprecated
    # ``use_auth_token`` argument.
    model_path = hf_hub_download(repo_id=repo_id, filename=filename, token=hf_token)
    # NOTE(review): joblib.load unpickles the file, which can execute arbitrary
    # code -- only point this at repositories you trust.
    return joblib.load(model_path)
# Apply Seaborn's "whitegrid" style globally for the plots drawn by this app.
# ``set_theme`` is the documented interface; ``sns.set`` is only an alias for it.
sns.set_theme(style="whitegrid")
def preprocess_and_predict(model, data, match_id, features):
    """Select one match's points and score them with ``model``.

    Processing order: keep rows whose ``match_id`` equals ``match_id``; keep
    only the first row for each ``elapsed_time`` value; drop rows with
    ``server == 0``; sort the remainder by ``elapsed_time``.

    Args:
        model: Estimator exposing ``predict_proba``; column 1 of its output is used.
        data: DataFrame containing ``match_id``, ``elapsed_time``, ``server``
            and every column listed in ``features``.
        match_id: Value to select in the ``match_id`` column.
        features: Column names fed to ``model.predict_proba``.

    Returns:
        ``(rows, probabilities)``: the filtered, time-ordered rows and the
        column-1 probability for each of them. Returns ``(None, None)``, after
        writing a message to the Streamlit page, when no rows remain.
    """
    match_rows = data[data['match_id'] == match_id].drop_duplicates(
        subset=['elapsed_time'], keep='first'
    )
    served_rows = match_rows[match_rows['server'] != 0]

    if served_rows.empty:
        st.write(f"No data found for match_id {match_id}.")
        return None, None

    ordered_rows = served_rows.sort_values('elapsed_time')
    scores = model.predict_proba(ordered_rows[features])
    return ordered_rows, scores[:, 1]
def _momentum_curve(probabilities):
    """Return one momentum value per predicted probability.

    The rule, applied to probabilities ``p``:
      * first point:                  (p[0] + 0.5) / 2
      * middle points:                (p[i-1] + p[i] + 0.5) / 3
      * last point (when n >= 2):     (p[-2] + p[-1] + p[-1]) / 3

    NOTE(review): middle points use the constant 0.5 as their "next" term
    rather than p[i+1]; kept as-is -- confirm this is the intended definition.
    """
    p = np.asarray(probabilities, dtype=float)
    momentum = np.empty_like(p)
    if p.size == 0:
        return momentum
    momentum[0] = (p[0] + 0.5) / 2
    if p.size > 1:
        next_term = np.full(p.size - 1, 0.5)
        next_term[-1] = p[-1]  # the last point uses its own probability
        momentum[1:] = (p[:-1] + p[1:] + next_term) / 3
    return momentum


def _smooth_curve(x, y, num=300):
    """Interpolate ``(x, y)`` onto ``num`` evenly spaced x values for plotting.

    Uses a cubic spline with at least 4 points and a linear one with 2-3
    points, because a degree-k spline needs at least k + 1 points. With fewer
    than 2 points the input is returned unchanged. Assumes ``x`` is strictly
    increasing.
    """
    x = np.asarray(x)
    y = np.asarray(y, dtype=float)
    if x.size < 2:
        return x, y
    degree = 3 if x.size >= 4 else 1
    x_smooth = np.linspace(x.min(), x.max(), num)
    return x_smooth, make_interp_spline(x, y, k=degree)(x_smooth)


def plot_results(specific_match_data, positive_class_probabilities, match_id, show_true_label=True, show_predicted_probability=True, show_momentum=True):
    """Render probability, true labels and momentum for one match in Streamlit.

    Args:
        specific_match_data: Rows of one match with ``elapsed_time``,
            ``point_victor``, ``game_no`` and ``set_no`` columns. Assumes they
            are ordered by strictly increasing ``elapsed_time`` (as produced by
            ``preprocess_and_predict``).
        positive_class_probabilities: One probability per row, in row order.
        match_id: Identifier shown in the plot title.
        show_true_label: Draw ``point_victor - 1`` as red points.
        show_predicted_probability: Draw the probabilities as a blue line.
        show_momentum: Draw the smoothed momentum curve as a green area.

    Game changes are marked with dashed gray vertical lines, and set changes
    with dash-dot red ones.
    """
    times = specific_match_data['elapsed_time'].to_numpy()
    probabilities = np.asarray(positive_class_probabilities)

    # Explicit figure so it can be handed to Streamlit and then closed; otherwise
    # every Streamlit rerun would leave another open pyplot figure behind.
    fig, ax = plt.subplots(figsize=(14, 7))

    if show_predicted_probability:
        sns.lineplot(
            x=times,
            y=probabilities,
            marker='o',
            linestyle='-',
            color='blue',
            label='Predicted Probability',
            ax=ax,
        )

    if show_true_label:
        # Shift point_victor down by one so it shares the probability axis
        # (presumably 1/2 -> 0/1; confirm against the data source).
        sns.scatterplot(
            x=times,
            y=specific_match_data['point_victor'].to_numpy() - 1,
            color='red',
            label='True Label',
            s=60,
            ax=ax,
        )

    if show_momentum and times.size > 0:
        momentum = _momentum_curve(probabilities)
        x_smooth, momentum_smooth = _smooth_curve(times, momentum)
        ax.fill_between(x_smooth, momentum_smooth, color='green', alpha=0.3)
        # A lone point has no line to draw, so give it a visible marker.
        ax.plot(
            x_smooth,
            momentum_smooth,
            color='green',
            marker='o' if times.size == 1 else None,
            label='Momentum',
        )

    # Vertical markers wherever game_no / set_no differs from the previous row.
    game_no = specific_match_data['game_no'].to_numpy()
    set_no = specific_match_data['set_no'].to_numpy()
    for t in times[1:][game_no[1:] != game_no[:-1]]:
        ax.axvline(x=t, color='gray', linestyle='--', lw=2)
    for t in times[1:][set_no[1:] != set_no[:-1]]:
        ax.axvline(x=t, color='red', linestyle='-.', lw=2)

    ax.set_title(f'Predicted Probability, Momentum, and True Label Over Time for Match {match_id}')
    ax.set_xlabel('Elapsed Time')
    ax.set_ylabel('Probability / True Label / Momentum')
    ax.grid(True)
    # Only build a legend when some layer is labelled (avoids matplotlib's
    # "no artists with labels" warning when every layer is switched off).
    if ax.get_legend_handles_labels()[0]:
        ax.legend()

    st.pyplot(fig)
    plt.close(fig)
def main():
    """Streamlit entry point for the Momentum Catcher app.

    Flow:
      1. The user uploads a CSV of point-level data.
      2. The user picks a ``match_id``, optionally narrowing it to specific sets/games.
      3. The XGBoost model downloaded via ``load_model`` scores each point.
      4. ``plot_results`` plots the probability, the true labels and the momentum.
    """
    st.title('Momentum Catcher')
    st.markdown("""
    To get started, you can find sample data available for download at
    [Hugging Face Spaces](https://huggingface.co/spaces/Nagi-ovo/Tennis-Momentum-Tracker/tree/main/data).
    This data can be used directly in this application to analyze tennis match momentum.
    """)
    uploaded_file = st.file_uploader("Upload your input CSV data", type="csv")
    if uploaded_file is None:
        return

    try:
        new_data = pd.read_csv(uploaded_file)
    except (pd.errors.EmptyDataError, pd.errors.ParserError) as err:
        st.error(f"Could not read the uploaded CSV: {err}")
        return

    # Columns read below or by preprocess_and_predict / plot_results; report
    # them clearly instead of failing later with a KeyError traceback.
    required_columns = {
        'match_id', 'elapsed_time', 'server', 'point_victor', 'set_no', 'game_no',
        'p1_facing_game_point', 'p2_facing_game_point', 'p1_PAI', 'p2_PAI',
        'normalized_rally',
    }
    missing_columns = sorted(required_columns.difference(new_data.columns))
    if missing_columns:
        st.error(f"Uploaded CSV is missing required column(s): {', '.join(missing_columns)}")
        return

    # Derived model inputs.
    features = ['PAI_diff', 'normalized_rally', 'is_game_point']
    new_data['is_game_point'] = (new_data['p1_facing_game_point'] - new_data['p2_facing_game_point']).abs()
    new_data['PAI_diff'] = new_data['p1_PAI'] - new_data['p2_PAI']

    # Rows without a match id or timestamp cannot be selected or placed on the
    # time axis. NaNs in other columns are left as they are.
    new_data = new_data.dropna(subset=['match_id', 'elapsed_time'])

    # Let the user choose one of the matches present in the upload.
    unique_match_ids = new_data['match_id'].unique()
    match_id_input = st.selectbox("Select the match_id you want to analyze", unique_match_ids)

    # Model location on the Hugging Face Hub.
    repo_id = "Nagi-ovo/Momentum-XGboost"
    filename = "xgboost.pkl"
    xgb_model = load_model(repo_id, filename)

    specific_match_data, positive_class_probabilities = preprocess_and_predict(
        xgb_model, new_data, match_id_input, features
    )
    if specific_match_data is None:
        return  # preprocess_and_predict has already written a message

    # Boolean row mask over specific_match_data. It is applied to both the rows
    # and their already-computed probabilities, so the model is not run twice.
    keep = np.ones(len(specific_match_data), dtype=bool)
    if st.checkbox("Observe specific set(s) and game(s)", False):
        unique_sets = specific_match_data['set_no'].unique()
        selected_sets = st.multiselect('Select set(s) to observe', unique_sets, default=unique_sets)
        # Deselecting every set falls back to showing the whole match.
        if selected_sets:
            keep = specific_match_data['set_no'].isin(selected_sets).to_numpy()
            unique_games = specific_match_data.loc[keep, 'game_no'].unique()
            selected_games = st.multiselect(
                'Select games to observe (default is all games)', unique_games, default=unique_games
            )
            # Game numbers are matched across all selected sets. An empty
            # selection keeps every game of those sets.
            if selected_games:
                keep = keep & specific_match_data['game_no'].isin(selected_games).to_numpy()

    # Layer toggles for the chart.
    show_true_label = st.checkbox("Show True Label", True)
    show_predicted_probability = st.checkbox("Show Predicted Probability", True)
    show_momentum = st.checkbox("Show Momentum", True)

    filtered_data = specific_match_data[keep]
    if filtered_data.empty:
        st.write("No data available for the selected set(s) and games. Please select again.")
        return

    plot_results(
        filtered_data,
        np.asarray(positive_class_probabilities)[keep],
        match_id_input,
        show_true_label,
        show_predicted_probability,
        show_momentum,
    )
# Build the app page only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()