piano_trans / app.py
admin
sync ms
da9545e
raw
history blame
8.25 kB
import os
import re
import json
import torch
import shutil
import requests
import modelscope
import huggingface_hub
import gradio as gr
from tqdm import tqdm
from piano_transcription_inference import PianoTranscription, load_audio, sample_rate
from urllib.parse import urlparse
from convert import midi2xml, xml2abc, xml2mxl, xml2jpg
# English UI unless the process locale is exactly Simplified Chinese (UTF-8);
# an unset LANG also yields the English UI.
EN_US = os.getenv("LANG") != "zh_CN.UTF-8"

# Chinese UI label -> English label; consulted by _L() when EN_US is true.
ZH2EN = {
    "上传模式": "Uploading Mode",
    "上传音频": "Upload an audio",
    "下载 MIDI": "Download MIDI",
    "下载 PDF 乐谱": "Download PDF score",
    "下载 MusicXML": "Download MusicXML",
    "下载 MXL": "Download MXL",
    "ABC 记谱": "ABC notation",
    "五线谱": "Staff",
    "状态栏": "Status",
    "请上传音频 100% 后再点提交": "Please make sure the audio is completely uploaded before clicking Submit",
    "直链模式": "Direct Link Mode",
    "输入音频 URL 直链": "Input audio direct link",
    "下载音频": "Download audio",
    "网易云音乐可直接输入非 VIP 歌曲页面链接自动解析": "For Netease Cloud music, you can directly input the non-VIP song page link",
    "# 钢琴转谱工具": "# Piano Transcription Tool",
}

# Snapshot the model repo from Hugging Face (English UI) or ModelScope
# (Chinese UI) into ./__pycache__, then point at the checkpoint file inside
# the returned snapshot directory. This runs at import time and presumably
# downloads on first use / reuses the cache afterwards.
WEIGHTS_PATH = (huggingface_hub if EN_US else modelscope).snapshot_download(
    "Genius-Society/piano_trans",
    cache_dir="./__pycache__",
) + "/CRNN_note_F1=0.9677_pedal_F1=0.9186.pth"
def _L(zh_txt: str):
    """Localize a UI string.

    Returns ``zh_txt`` unchanged for the Chinese UI; otherwise returns its
    English translation from ``ZH2EN`` (raises KeyError if there is none).
    """
    if not EN_US:
        return zh_txt
    return ZH2EN[zh_txt]
def clean_cache(cache_dir):
    """Recreate ``cache_dir`` as an empty directory.

    An existing directory at that path is removed with all its contents; a
    plain file or symlink at that path is unlinked. The directory is then
    created again, including any missing parent directories.
    """
    if os.path.islink(cache_dir) or os.path.isfile(cache_dir):
        # shutil.rmtree refuses symlinks and plain files; unlink them instead.
        os.remove(cache_dir)
    elif os.path.isdir(cache_dir):
        shutil.rmtree(cache_dir)
    # makedirs (unlike mkdir) also creates missing parent directories, so
    # nested cache paths work even when their parents do not exist yet.
    os.makedirs(cache_dir, exist_ok=True)
def download_audio(url: str, save_path: str, timeout: float = 30):
    """Stream ``url`` into ``save_path`` while showing a tqdm progress bar.

    Args:
        url: HTTP(S) address of the audio file.
        save_path: Destination file; overwritten if it already exists.
        timeout: Seconds handed to ``requests`` as the connect / per-read
            timeout, so a stalled server raises instead of blocking forever.
            It does not cap the total download time.

    Raises:
        requests.HTTPError: the server answered with a 4xx/5xx status.
        requests.RequestException: connection failure or timeout.
    """
    # Streaming request; the ``with`` releases the connection even when
    # raise_for_status() or a file write fails midway.
    with requests.get(url, stream=True, timeout=timeout) as response:
        response.raise_for_status()
        # Total size in bytes; 0 when the server sends no Content-Length.
        total = int(response.headers.get("content-length", 0))
        with open(save_path, "wb") as file, tqdm(
            desc=save_path,  # progress-bar prefix
            total=total,
            unit="B",  # count bytes
            unit_scale=True,  # auto-scale to KB/MB/...
            unit_divisor=1024,  # 1 KB = 1024 bytes
        ) as pbar:
            # Write in 8 KiB chunks.
            for chunk in response.iter_content(chunk_size=8192):
                if chunk:  # skip empty keep-alive chunks
                    file.write(chunk)
                    pbar.update(len(chunk))
def is_url(s: str):
    """Return True if ``s`` parses as a URL with both a scheme and a host.

    Inputs that make ``urlparse`` raise ValueError, TypeError or
    AttributeError (e.g. a malformed IPv6 host such as ``"http://[::1"``,
    or a non-string value like an int) yield False instead of raising.
    """
    try:
        result = urlparse(s)
    except (ValueError, TypeError, AttributeError):
        # Narrow catch: a bare ``except`` would also swallow
        # KeyboardInterrupt and SystemExit.
        return False
    # Both the scheme (http, https, ...) and the network location (domain)
    # must be present.
    return all([result.scheme, result.netloc])
def audio2midi(audio_path: str, cache_dir: str):
    """Transcribe a piano recording to a MIDI file.

    Args:
        audio_path: Path of the input audio file.
        cache_dir: Existing directory that receives ``output.mid``.

    Returns:
        ``(midi_path, title)``: the written MIDI path, and the audio file's
        name without its extension passed through ``str.capitalize``.
    """
    # Load as mono at the sample rate exported by piano_transcription_inference.
    audio, _ = load_audio(audio_path, sr=sample_rate, mono=True)
    transcriptor = PianoTranscription(
        device="cuda" if torch.cuda.is_available() else "cpu",
        checkpoint_path=WEIGHTS_PATH,
    )
    midi_path = f"{cache_dir}/output.mid"
    transcriptor.transcribe(audio, midi_path)
    # splitext keeps dotted names intact ("a.b.mp3" -> "a.b") and, unlike
    # split(".")[-2], does not raise IndexError for names without a dot.
    title = os.path.splitext(os.path.basename(audio_path))[0]
    return midi_path, title.capitalize()
def extract_fst_int(input_string: str):
    """Return the first run of digits in ``input_string`` as an integer
    string with leading zeros removed, or "" if it contains no digits."""
    found = re.search(r"\d+", input_string)
    return "" if found is None else str(int(found.group()))
def music163_song_info(id: str, timeout: float = 30):
    """Look up a NetEase Cloud Music song's name and whether it is free.

    Args:
        id: Song ID as a string. (It shadows the ``id`` builtin; the name is
            kept so keyword callers keep working.)
        timeout: Seconds handed to ``requests`` as the connect/read timeout,
            so an unresponsive API raises instead of hanging.

    Returns:
        ``(song_name, free)``. ``free`` is True only when the song's ``fee``
        field is 0 or 8. When the API returns no song, ``song_name`` is
        "歌曲不存在" ("song does not exist") and ``free`` is False.

    Raises:
        ConnectionError: the API answered with a non-200 status.
        requests.RequestException: network failure or timeout.
    """
    detail_api = "https://music.163.com/api/v3/song/detail"
    # NOTE(review): ``c`` is the Python repr of a list of dicts (single
    # quotes), not strict JSON. Kept as-is; confirm the endpoint's expected
    # format before changing it.
    parm_dict = {"id": id, "c": str([{"id": id}]), "csrf_token": ""}
    response = requests.get(detail_api, params=parm_dict, timeout=timeout)
    if response.status_code != 200:
        raise ConnectionError(f"错误: {response.status_code}, {response.text}")
    data = json.loads(response.text)
    if not (data and "songs" in data and data["songs"]):
        return "歌曲不存在", False
    song = data["songs"][0]
    fee = int(song["fee"])
    # fee values 0 and 8 are treated as freely downloadable.
    free = fee == 0 or fee == 8
    return str(song["name"]), free
def upl_infer(audio_path: str, cache_dir="./__pycache__/mode1"):
    """Run the full transcription pipeline on an uploaded audio file.

    Empties ``cache_dir``, transcribes ``audio_path`` to MIDI, then derives
    MusicXML, ABC, MXL, and the PDF / staff image from it.

    Returns:
        ``(status, midi, pdf, xml, mxl, abc, jpg)``. ``status`` is "Success",
        or the text of the first exception raised; any output not produced
        before that failure is None.
    """
    midi = pdf = xml = mxl = abc = jpg = None
    try:
        clean_cache(cache_dir)
        midi, title = audio2midi(audio_path, cache_dir)
        xml = midi2xml(midi, title)
        abc = xml2abc(xml)
        mxl = xml2mxl(xml)
        pdf, jpg = xml2jpg(xml)
    except Exception as err:
        # Surface the error text in the UI status box instead of crashing.
        status = f"{err}"
    else:
        status = "Success"
    return status, midi, pdf, xml, mxl, abc, jpg
def url_infer(song: str, cache_dir="./__pycache__/mode2"):
    """Download audio from a link (or NetEase song ID) and transcribe it.

    ``song`` may be a direct audio URL, a NetEase Cloud Music page URL that
    contains ``?id=``, or a bare numeric NetEase song ID. NetEase songs are
    resolved to their outer media URL and rejected unless they are free.

    Returns:
        ``(status, audio, midi, pdf, xml, mxl, abc, jpg)``. ``status`` is
        "Success", or the text of the first exception raised; any output not
        produced before that failure is None.
    """
    song_name = ""
    status = "Success"
    audio = midi = pdf = xml = mxl = abc = jpg = None
    try:
        # Tolerate stray whitespace pasted into the textbox, which would
        # otherwise make a bare ID fail isdigit().
        song = song.strip()
        clean_cache(cache_dir)
        download_path = f"{cache_dir}/output.mp3"
        if (is_url(song) and "163" in song and "?id=" in song) or song.isdigit():
            song_id = extract_fst_int(song.split("?id=")[-1])
            if not song_id:
                # Otherwise the empty ID would be reported as a paid song.
                raise ValueError("无法从链接中解析歌曲 ID")
            song = f"https://music.163.com/song/media/outer/url?id={song_id}.mp3"
            song_name, free = music163_song_info(song_id)
            if not free:
                raise AttributeError("付费歌曲无法解析")
        download_audio(song, download_path)
        if not os.path.exists(download_path):
            # The file is missing, so FileNotFoundError (not FileExistsError).
            raise FileNotFoundError(f"{download_path} not exist")
        midi, title = audio2midi(download_path, cache_dir)
        # Prefer the NetEase song name over the generic "Output" title.
        if song_name:
            title = song_name
        audio = download_path
        xml = midi2xml(midi, title)
        abc = xml2abc(xml)
        mxl = xml2mxl(xml)
        pdf, jpg = xml2jpg(xml)
    except Exception as e:
        status = f"{e}"
    return status, audio, midi, pdf, xml, mxl, abc, jpg
if __name__ == "__main__":
    # Build the web UI: a title plus one tab per input mode, each tab wrapping
    # a gr.Interface around the matching inference function.
    with gr.Blocks() as iface:
        gr.Markdown(_L("# 钢琴转谱工具"))
        # Tab 1: transcribe an uploaded audio file with upl_infer.
        with gr.Tab(_L("上传模式")):
            gr.Interface(
                fn=upl_infer,
                inputs=gr.Audio(label=_L("上传音频"), type="filepath"),
                # Order must match upl_infer's return tuple:
                # (status, midi, pdf, xml, mxl, abc, jpg).
                outputs=[
                    gr.Textbox(label=_L("状态栏"), show_copy_button=True),
                    gr.File(label=_L("下载 MIDI")),
                    gr.File(label=_L("下载 PDF 乐谱")),
                    gr.File(label=_L("下载 MusicXML")),
                    gr.File(label=_L("下载 MXL")),
                    gr.Textbox(label=_L("ABC 记谱"), show_copy_button=True),
                    gr.Image(
                        label=_L("五线谱"),
                        type="filepath",
                        show_share_button=False,
                    ),
                ],
                title=_L("请上传音频 100% 后再点提交"),
                flagging_mode="never",
            )
        # Tab 2 is only built for the Chinese UI (EN_US false): fetch audio
        # from a direct URL or NetEase song link/ID with url_infer.
        if not EN_US:
            with gr.Tab(_L("直链模式")):
                gr.Interface(
                    fn=url_infer,
                    inputs=gr.Textbox(
                        label=_L("输入音频 URL 直链"),
                        placeholder="https://music.163.com/#/song?id=",
                    ),
                    # Order must match url_infer's return tuple:
                    # (status, audio, midi, pdf, xml, mxl, abc, jpg).
                    outputs=[
                        gr.Textbox(label=_L("状态栏"), show_copy_button=True),
                        gr.Audio(label=_L("下载音频"), type="filepath"),
                        gr.File(label=_L("下载 MIDI")),
                        gr.File(label=_L("下载 PDF 乐谱")),
                        gr.File(label=_L("下载 MusicXML")),
                        gr.File(label=_L("下载 MXL")),
                        gr.Textbox(label=_L("ABC 记谱"), show_copy_button=True),
                        gr.Image(
                            label=_L("五线谱"),
                            type="filepath",
                            show_share_button=False,
                        ),
                    ],
                    title=_L("网易云音乐可直接输入非 VIP 歌曲页面链接自动解析"),
                    # Example inputs are bare NetEase song IDs.
                    examples=["1945798894", "1945798973", "1946098771"],
                    flagging_mode="never",
                    # Do not pre-run the examples (each run would download
                    # and transcribe audio).
                    cache_examples=False,
                )
    iface.launch()