Spaces:
Running
Running
import os | |
import re | |
import json | |
import torch | |
import shutil | |
import requests | |
import modelscope | |
import huggingface_hub | |
import gradio as gr | |
from tqdm import tqdm | |
from piano_transcription_inference import PianoTranscription, load_audio, sample_rate | |
from urllib.parse import urlparse | |
from convert import midi2xml, xml2abc, xml2mxl, xml2jpg | |
# UI language switch: English unless the LANG environment variable is exactly
# "zh_CN.UTF-8".
EN_US = os.environ.get("LANG") != "zh_CN.UTF-8"

# Chinese UI string -> English translation, keyed by the exact Chinese text
# that call sites pass to _L().
ZH2EN = {
    # Upload tab and shared output labels
    "上传模式": "Uploading Mode",
    "上传音频": "Upload an audio",
    "下载 MIDI": "Download MIDI",
    "下载 PDF 乐谱": "Download PDF score",
    "下载 MusicXML": "Download MusicXML",
    "下载 MXL": "Download MXL",
    "ABC 记谱": "ABC notation",
    "五线谱": "Staff",
    "状态栏": "Status",
    "请上传音频 100% 后再点提交": "Please make sure the audio is completely uploaded before clicking Submit",
    # Direct-link tab
    "直链模式": "Direct Link Mode",
    "输入音频 URL 直链": "Input audio direct link",
    "下载音频": "Download audio",
    "网易云音乐可直接输入非 VIP 歌曲页面链接自动解析": "For Netease Cloud music, you can directly input the non-VIP song page link",
    # Page heading
    "# 钢琴转谱工具": "# Piano Transcription Tool",
}
# Resolve the model checkpoint at import time. The snapshot is pulled from the
# Hugging Face Hub when the UI is English and from ModelScope otherwise; both
# calls use the same repo id and the same local cache directory.
if EN_US:
    WEIGHTS_PATH = huggingface_hub.snapshot_download(
        "Genius-Society/piano_trans",
        cache_dir="./__pycache__",
    )
else:
    WEIGHTS_PATH = modelscope.snapshot_download(
        "Genius-Society/piano_trans",
        cache_dir="./__pycache__",
    )

# Point at the checkpoint file inside the downloaded snapshot directory.
WEIGHTS_PATH += "/CRNN_note_F1=0.9677_pedal_F1=0.9186.pth"
def _L(zh_txt: str):
    """Localize a UI string.

    Returns the English entry of ``ZH2EN`` for *zh_txt* when ``EN_US`` is
    true (raising ``KeyError`` if no entry exists); otherwise returns
    *zh_txt* unchanged.
    """
    if EN_US:
        return ZH2EN[zh_txt]

    return zh_txt
def clean_cache(cache_dir):
    """Reset *cache_dir* to an empty directory.

    Any existing directory tree at that path is removed, then the directory
    is created again. Missing parent directories are created too, so a
    nested path such as ``./__pycache__/mode1`` works even when the parent
    does not exist yet.
    """
    if os.path.exists(cache_dir):
        shutil.rmtree(cache_dir)

    # makedirs (unlike mkdir) also creates intermediate directories.
    os.makedirs(cache_dir, exist_ok=True)
def download_audio(url: str, save_path: str, timeout: float = 30):
    """Stream the resource at *url* into *save_path* with a tqdm progress bar.

    Args:
        url: Address to fetch with an HTTP GET.
        save_path: Destination file; opened in binary write mode.
        timeout: Seconds passed to ``requests.get`` as its ``timeout``
            argument, so a stalled server cannot block the caller forever.

    Raises:
        requests.HTTPError: If the server answers with an error status
            (via ``raise_for_status``).
        requests.RequestException: On connection problems or timeouts.
    """
    # Streamed request; the context manager releases the connection even if
    # writing fails part-way through.
    with requests.get(url, stream=True, timeout=timeout) as response:
        response.raise_for_status()
        # Total size in bytes; 0 when the server sends no Content-Length.
        total = int(response.headers.get("content-length", 0))
        with open(save_path, "wb") as file, tqdm(
            desc=save_path,  # progress bar prefix
            total=total,  # expected size
            unit="B",  # count in bytes
            unit_scale=True,  # auto-scale the displayed unit
            unit_divisor=1024,  # 1024 bytes = 1 KB
        ) as pbar:
            # Write the body in 8 KB chunks.
            for chunk in response.iter_content(chunk_size=8192):
                if chunk:  # skip empty keep-alive chunks
                    file.write(chunk)
                    pbar.update(len(chunk))
def is_url(s: str):
    """Return True when *s* parses as a URL with both a scheme and a netloc.

    Input that ``urlparse`` rejects (malformed URLs, non-string values)
    yields False instead of raising.
    """
    try:
        result = urlparse(s)
    except (ValueError, TypeError, AttributeError):
        # Only parsing failures are swallowed; KeyboardInterrupt and
        # SystemExit are no longer caught as they were by a bare except.
        return False

    # Both the scheme (e.g. http, https) and the network location (domain)
    # must be present.
    return bool(result.scheme and result.netloc)
def audio2midi(audio_path: str, cache_dir: str):
    """Transcribe the audio file at *audio_path* to ``{cache_dir}/output.mid``.

    Args:
        audio_path: Path of the audio file to transcribe.
        cache_dir: Directory that receives ``output.mid``.

    Returns:
        A ``(midi_path, title)`` tuple. ``title`` is the file name of
        *audio_path* without its final extension, passed through
        ``str.capitalize``.
    """
    # Load as mono at the sample rate exported by piano_transcription_inference.
    audio, _ = load_audio(audio_path, sr=sample_rate, mono=True)
    transcriptor = PianoTranscription(
        device="cuda" if torch.cuda.is_available() else "cpu",
        checkpoint_path=WEIGHTS_PATH,
    )
    midi_path = f"{cache_dir}/output.mid"
    # NOTE(review): transcribe() presumably writes the MIDI file to midi_path;
    # confirm against the library.
    transcriptor.transcribe(audio, midi_path)

    # splitext copes with names that have no extension (the previous
    # split(".")[-2] raised IndexError) and keeps the whole stem of
    # multi-dot names such as "my.song.mp3".
    stem = os.path.splitext(os.path.basename(audio_path))[0]
    return midi_path, stem.capitalize()
def extract_fst_int(input_string: str):
    """Return the first run of digits in *input_string* as a string.

    The digits are round-tripped through ``int``, which drops leading zeros.
    An empty string is returned when *input_string* contains no digit.
    """
    first_digits = re.search(r"\d+", input_string)
    if first_digits is None:
        return ""

    return str(int(first_digits.group()))
def music163_song_info(id: str, timeout: float = 10):
    """Look up a song's name and free/paid status via the NetEase detail API.

    Args:
        id: Song id sent to the API. (The name shadows the builtin but is
            kept so keyword callers keep working.)
        timeout: Seconds passed to ``requests.get`` as its ``timeout``
            argument, so an unresponsive API cannot hang the caller.

    Returns:
        A ``(song_name, free)`` tuple. ``free`` is True when the song's
        ``fee`` field is 0 or 8. When the response carries no songs the
        result is ``("歌曲不存在", False)``.

    Raises:
        ConnectionError: If the API answers with a status other than 200.
        requests.RequestException: On connection problems or timeouts.
    """
    detail_api = "https://music.163.com/api/v3/song/detail"
    parm_dict = {"id": id, "c": str([{"id": id}]), "csrf_token": ""}
    response = requests.get(detail_api, params=parm_dict, timeout=timeout)

    # Anything but HTTP 200 is reported to the caller as a failure.
    if response.status_code != 200:
        raise ConnectionError(f"错误: {response.status_code}, {response.text}")

    data = json.loads(response.text)
    if not (data and "songs" in data and data["songs"]):
        return "歌曲不存在", False

    first_song = data["songs"][0]
    fee = int(first_song["fee"])
    # Fee codes 0 and 8 are the ones this app treats as freely downloadable.
    return str(first_song["name"]), fee in (0, 8)
def upl_infer(audio_path: str, cache_dir="./__pycache__/mode1"):
    """Run the transcription pipeline on an uploaded audio file.

    The cache directory is reset, the audio is transcribed to MIDI, and the
    MIDI is passed through ``midi2xml``, ``xml2abc``, ``xml2mxl`` and
    ``xml2jpg`` in that order.

    Returns:
        ``(status, midi, pdf, xml, mxl, abc, jpg)``. ``status`` is
        ``"Success"`` or, if any step raised, the text of that exception.
        Outputs of steps that finished before a failure are kept; the
        remaining ones are None.
    """
    midi = pdf = xml = mxl = abc = jpg = None
    try:
        clean_cache(cache_dir)
        midi, title = audio2midi(audio_path, cache_dir)
        xml = midi2xml(midi, title)
        abc = xml2abc(xml)
        mxl = xml2mxl(xml)
        pdf, jpg = xml2jpg(xml)
    except Exception as err:
        # Surface the failure in the status box instead of crashing the UI.
        status = str(err)
    else:
        status = "Success"

    return status, midi, pdf, xml, mxl, abc, jpg
def url_infer(song: str, cache_dir="./__pycache__/mode2"):
    """Download audio from a direct link or NetEase song reference and transcribe it.

    *song* may be a direct audio URL, a NetEase page URL containing
    ``?id=``, or a bare numeric NetEase song id. NetEase references are
    rewritten to the outer media URL and rejected when
    ``music163_song_info`` reports the song as not free.

    Returns:
        ``(status, audio, midi, pdf, xml, mxl, abc, jpg)``. ``status`` is
        ``"Success"`` or, if any step raised, the text of that exception.
        Outputs of steps that finished before a failure are kept; the
        remaining ones are None.
    """
    song_name = ""
    status = "Success"
    audio = midi = pdf = xml = mxl = abc = jpg = None
    try:
        clean_cache(cache_dir)
        download_path = f"{cache_dir}/output.mp3"
        # Pasted links often carry stray whitespace or a trailing newline.
        song = song.strip()
        if (is_url(song) and "163" in song and "?id=" in song) or song.isdigit():
            song_id = extract_fst_int(song.split("?id=")[-1])
            if not song_id:
                # Without this guard the API would be queried with an empty id.
                raise ValueError("无法从输入中解析歌曲 ID")

            song = f"https://music.163.com/song/media/outer/url?id={song_id}.mp3"
            song_name, free = music163_song_info(song_id)
            if not free:
                raise PermissionError("付费歌曲无法解析")

        download_audio(song, download_path)
        if not os.path.exists(download_path):
            # The file is missing, so FileNotFoundError (not FileExistsError)
            # is the matching exception; the status text is unchanged.
            raise FileNotFoundError(f"{download_path} not exist")

        midi, title = audio2midi(download_path, cache_dir)
        # Prefer the song name from the NetEase API over the file-derived title.
        if song_name:
            title = song_name

        audio = download_path
        xml = midi2xml(midi, title)
        abc = xml2abc(xml)
        mxl = xml2mxl(xml)
        pdf, jpg = xml2jpg(xml)
    except Exception as e:
        # Surface the failure in the status box instead of crashing the UI.
        status = f"{e}"

    return status, audio, midi, pdf, xml, mxl, abc, jpg
if __name__ == "__main__":
    # Build the Gradio UI: an upload tab that is always shown, plus a
    # direct-link tab that is shown only when the UI is not English.
    with gr.Blocks() as iface:
        gr.Markdown(_L("# 钢琴转谱工具"))

        with gr.Tab(_L("上传模式")):
            upload_input = gr.Audio(label=_L("上传音频"), type="filepath")
            # Order matches the tuple returned by upl_infer.
            upload_outputs = [
                gr.Textbox(label=_L("状态栏"), show_copy_button=True),
                gr.File(label=_L("下载 MIDI")),
                gr.File(label=_L("下载 PDF 乐谱")),
                gr.File(label=_L("下载 MusicXML")),
                gr.File(label=_L("下载 MXL")),
                gr.Textbox(label=_L("ABC 记谱"), show_copy_button=True),
                gr.Image(
                    label=_L("五线谱"),
                    type="filepath",
                    show_share_button=False,
                ),
            ]
            gr.Interface(
                fn=upl_infer,
                inputs=upload_input,
                outputs=upload_outputs,
                title=_L("请上传音频 100% 后再点提交"),
                flagging_mode="never",
            )

        if not EN_US:
            with gr.Tab(_L("直链模式")):
                link_input = gr.Textbox(
                    label=_L("输入音频 URL 直链"),
                    placeholder="https://music.163.com/#/song?id=",
                )
                # Order matches the tuple returned by url_infer.
                link_outputs = [
                    gr.Textbox(label=_L("状态栏"), show_copy_button=True),
                    gr.Audio(label=_L("下载音频"), type="filepath"),
                    gr.File(label=_L("下载 MIDI")),
                    gr.File(label=_L("下载 PDF 乐谱")),
                    gr.File(label=_L("下载 MusicXML")),
                    gr.File(label=_L("下载 MXL")),
                    gr.Textbox(label=_L("ABC 记谱"), show_copy_button=True),
                    gr.Image(
                        label=_L("五线谱"),
                        type="filepath",
                        show_share_button=False,
                    ),
                ]
                gr.Interface(
                    fn=url_infer,
                    inputs=link_input,
                    outputs=link_outputs,
                    title=_L("网易云音乐可直接输入非 VIP 歌曲页面链接自动解析"),
                    examples=["1945798894", "1945798973", "1946098771"],
                    flagging_mode="never",
                    cache_examples=False,
                )

    iface.launch()