SohomToom committed on
Commit cad5d5d · verified · 1 parent: ab0bdb4

Upload 91 files

This view is limited to 50 files because the commit contains too many changes. See the raw diff for the full set of changes.
Files changed (50)
  1. .gitattributes +2 -0
  2. MeloTTS/.github/workflows/pypi.yml +40 -0
  3. MeloTTS/.gitignore +11 -0
  4. MeloTTS/Dockerfile +13 -0
  5. MeloTTS/LICENSE +19 -0
  6. MeloTTS/README.md +62 -0
  7. MeloTTS/docs/install.md +230 -0
  8. MeloTTS/docs/quick_use.md +45 -0
  9. MeloTTS/docs/training.md +37 -0
  10. MeloTTS/logo.png +3 -0
  11. MeloTTS/melo/__init__.py +0 -0
  12. MeloTTS/melo/api.py +135 -0
  13. MeloTTS/melo/app.py +61 -0
  14. MeloTTS/melo/attentions.py +459 -0
  15. MeloTTS/melo/commons.py +160 -0
  16. MeloTTS/melo/configs/config.json +94 -0
  17. MeloTTS/melo/data/example/metadata.list +20 -0
  18. MeloTTS/melo/data_utils.py +413 -0
  19. MeloTTS/melo/download_utils.py +67 -0
  20. MeloTTS/melo/infer.py +25 -0
  21. MeloTTS/melo/init_downloads.py +14 -0
  22. MeloTTS/melo/losses.py +58 -0
  23. MeloTTS/melo/main.py +36 -0
  24. MeloTTS/melo/mel_processing.py +174 -0
  25. MeloTTS/melo/models.py +1030 -0
  26. MeloTTS/melo/modules.py +598 -0
  27. MeloTTS/melo/monotonic_align/__init__.py +16 -0
  28. MeloTTS/melo/monotonic_align/core.py +46 -0
  29. MeloTTS/melo/preprocess_text.py +135 -0
  30. MeloTTS/melo/split_utils.py +174 -0
  31. MeloTTS/melo/text/__init__.py +35 -0
  32. MeloTTS/melo/text/chinese.py +199 -0
  33. MeloTTS/melo/text/chinese_bert.py +107 -0
  34. MeloTTS/melo/text/chinese_mix.py +253 -0
  35. MeloTTS/melo/text/cleaner.py +36 -0
  36. MeloTTS/melo/text/cleaner_multiling.py +110 -0
  37. MeloTTS/melo/text/cmudict.rep +0 -0
  38. MeloTTS/melo/text/cmudict_cache.pickle +3 -0
  39. MeloTTS/melo/text/english.py +284 -0
  40. MeloTTS/melo/text/english_bert.py +39 -0
  41. MeloTTS/melo/text/english_utils/__init__.py +0 -0
  42. MeloTTS/melo/text/english_utils/abbreviations.py +35 -0
  43. MeloTTS/melo/text/english_utils/number_norm.py +97 -0
  44. MeloTTS/melo/text/english_utils/time_norm.py +47 -0
  45. MeloTTS/melo/text/es_phonemizer/__init__.py +0 -0
  46. MeloTTS/melo/text/es_phonemizer/base.py +140 -0
  47. MeloTTS/melo/text/es_phonemizer/cleaner.py +109 -0
  48. MeloTTS/melo/text/es_phonemizer/es_symbols.json +79 -0
  49. MeloTTS/melo/text/es_phonemizer/es_symbols.txt +1 -0
  50. MeloTTS/melo/text/es_phonemizer/es_symbols_v2.json +83 -0
.gitattributes CHANGED
@@ -40,3 +40,5 @@ openvoice/resources/example_reference.mp3 filter=lfs diff=lfs merge=lfs -text
openvoice/resources/openvoicelogo.jpg filter=lfs diff=lfs merge=lfs -text
openvoice/resources/tts-guide.png filter=lfs diff=lfs merge=lfs -text
openvoice/resources/voice-clone-guide.png filter=lfs diff=lfs merge=lfs -text
+ MeloTTS/logo.png filter=lfs diff=lfs merge=lfs -text
+ MeloTTS/melo/text/fr_phonemizer/example_ipa.txt filter=lfs diff=lfs merge=lfs -text
MeloTTS/.github/workflows/pypi.yml ADDED
@@ -0,0 +1,40 @@
1
+ # This workflow will upload a Python Package using Twine when a release is created
2
+ # For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python#publishing-to-package-registries
3
+
4
+ # This workflow uses actions that are not certified by GitHub.
5
+ # They are provided by a third-party and are governed by
6
+ # separate terms of service, privacy policy, and support
7
+ # documentation.
8
+
9
+ name: Upload Python Package
10
+
11
+ on:
12
+ release:
13
+ types: [published]
14
+
15
+ permissions:
16
+ contents: read
17
+
18
+ jobs:
19
+ deploy:
20
+
21
+ runs-on: ubuntu-latest
22
+
23
+ steps:
24
+ - uses: actions/checkout@v3
25
+ - name: Set up Python
26
+ uses: actions/setup-python@v3
27
+ with:
28
+ python-version: '3.x'
29
+ - name: Install dependencies
30
+ run: |
31
+ python -m pip install --upgrade pip
32
+ python -m ensurepip --upgrade
33
+ pip install build
34
+ - name: Build package
35
+ run: python -m build
36
+ - name: Publish package
37
+ uses: pypa/gh-action-pypi-publish@release/v1.8
38
+ with:
39
+ user: __token__
40
+ password: ${{ secrets.PYPI_API_TOKEN }}
MeloTTS/.gitignore ADDED
@@ -0,0 +1,11 @@
1
+ __pycache__/
2
+ .ipynb_checkpoints/
3
+ basetts_outputs_use_bert/
4
+ basetts_outputs/
5
+ multilingual_ckpts
6
+ basetts_outputs_package/
7
+ build/
8
+ *.egg-info/
9
+
10
+ *.zip
11
+ *.wav
MeloTTS/Dockerfile ADDED
@@ -0,0 +1,13 @@
1
+ FROM python:3.9-slim
2
+ WORKDIR /app
3
+ COPY . /app
4
+
5
+ RUN apt-get update && apt-get install -y \
6
+ build-essential libsndfile1 \
7
+ && rm -rf /var/lib/apt/lists/*
8
+
9
+ RUN pip install -e .
10
+ RUN python -m unidic download
11
+ RUN python melo/init_downloads.py
12
+
13
+ CMD ["python", "./melo/app.py", "--host", "0.0.0.0", "--port", "8888"]
MeloTTS/LICENSE ADDED
@@ -0,0 +1,19 @@
1
+ Copyright (c) 2024 MyShell.ai
2
+
3
+ Permission is hereby granted, free of charge, to any person obtaining a copy
4
+ of this software and associated documentation files (the "Software"), to deal
5
+ in the Software without restriction, including without limitation the rights
6
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
7
+ copies of the Software, and to permit persons to whom the Software is
8
+ furnished to do so, subject to the following conditions:
9
+
10
+ The above copyright notice and this permission notice shall be included in all
11
+ copies or substantial portions of the Software.
12
+
13
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
18
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
19
+ SOFTWARE.
MeloTTS/README.md ADDED
@@ -0,0 +1,62 @@
1
+ <div align="center">
2
+ <div>&nbsp;</div>
3
+ <img src="logo.png" width="300"/> <br>
4
+ <a href="https://trendshift.io/repositories/8133" target="_blank"><img src="https://trendshift.io/api/badge/repositories/8133" alt="myshell-ai%2FMeloTTS | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a>
5
+ </div>
6
+
7
+ ## Introduction
8
+ MeloTTS is a **high-quality multi-lingual** text-to-speech library by [MIT](https://www.mit.edu/) and [MyShell.ai](https://myshell.ai). Supported languages include:
9
+
10
+ | Language | Example |
11
+ | --- | --- |
12
+ | English (American) | [Link](https://myshell-public-repo-host.s3.amazonaws.com/myshellttsbase/examples/en/EN-US/speed_1.0/sent_000.wav) |
13
+ | English (British) | [Link](https://myshell-public-repo-host.s3.amazonaws.com/myshellttsbase/examples/en/EN-BR/speed_1.0/sent_000.wav) |
14
+ | English (Indian) | [Link](https://myshell-public-repo-host.s3.amazonaws.com/myshellttsbase/examples/en/EN_INDIA/speed_1.0/sent_000.wav) |
15
+ | English (Australian) | [Link](https://myshell-public-repo-host.s3.amazonaws.com/myshellttsbase/examples/en/EN-AU/speed_1.0/sent_000.wav) |
16
+ | English (Default) | [Link](https://myshell-public-repo-host.s3.amazonaws.com/myshellttsbase/examples/en/EN-Default/speed_1.0/sent_000.wav) |
17
+ | Spanish | [Link](https://myshell-public-repo-host.s3.amazonaws.com/myshellttsbase/examples/es/ES/speed_1.0/sent_000.wav) |
18
+ | French | [Link](https://myshell-public-repo-host.s3.amazonaws.com/myshellttsbase/examples/fr/FR/speed_1.0/sent_000.wav) |
19
+ | Chinese (mix EN) | [Link](https://myshell-public-repo-host.s3.amazonaws.com/myshellttsbase/examples/zh/ZH/speed_1.0/sent_008.wav) |
20
+ | Japanese | [Link](https://myshell-public-repo-host.s3.amazonaws.com/myshellttsbase/examples/jp/JP/speed_1.0/sent_000.wav) |
21
+ | Korean | [Link](https://myshell-public-repo-host.s3.amazonaws.com/myshellttsbase/examples/kr/KR/speed_1.0/sent_000.wav) |
22
+
23
+ Some other features include:
24
+ - The Chinese speaker supports `mixed Chinese and English`.
25
+ - Fast enough for `CPU real-time inference`.
26
+
27
+ ## Usage
28
+ - [Use without Installation](docs/quick_use.md)
29
+ - [Install and Use Locally](docs/install.md)
30
+ - [Training on Custom Dataset](docs/training.md)
31
+
32
+ The Python API and model cards can be found in [this repo](https://github.com/myshell-ai/MeloTTS/blob/main/docs/install.md#python-api) or on [HuggingFace](https://huggingface.co/myshell-ai).
33
+
34
+ **Contributing**
35
+
36
+ If you find this work useful, please consider contributing to this repo.
37
+
38
+ - Many thanks to [@fakerybakery](https://github.com/fakerybakery) for adding the Web UI and CLI part.
39
+
40
+ ## Authors
41
+
42
+ - [Wenliang Zhao](https://wl-zhao.github.io) at Tsinghua University
43
+ - [Xumin Yu](https://yuxumin.github.io) at Tsinghua University
44
+ - [Zengyi Qin](https://www.qinzy.tech) (project lead) at MIT and MyShell
45
+
46
+ **Citation**
47
+ ```
48
+ @software{zhao2024melo,
49
+ author={Zhao, Wenliang and Yu, Xumin and Qin, Zengyi},
50
+ title = {MeloTTS: High-quality Multi-lingual Multi-accent Text-to-Speech},
51
+ url = {https://github.com/myshell-ai/MeloTTS},
52
+ year = {2023}
53
+ }
54
+ ```
55
+
56
+ ## License
57
+
58
+ This library is under the MIT License, which means it is free for both commercial and non-commercial use.
59
+
60
+ ## Acknowledgements
61
+
62
+ This implementation is based on [TTS](https://github.com/coqui-ai/TTS), [VITS](https://github.com/jaywalnut310/vits), [VITS2](https://github.com/daniilrobnikov/vits2) and [Bert-VITS2](https://github.com/fishaudio/Bert-VITS2). We appreciate their awesome work.
MeloTTS/docs/install.md ADDED
@@ -0,0 +1,230 @@
1
+ ## Install and Use Locally
2
+
3
+ ### Table of Contents
4
+ - [Linux and macOS Install](#linux-and-macos-install)
5
+ - [Docker Install for Windows and macOS](#docker-install)
6
+ - [Usage](#usage)
7
+ - [Web UI](#webui)
8
+ - [CLI](#cli)
9
+ - [Python API](#python-api)
10
+
11
+ ### Linux and macOS Install
12
+ The repo is developed and tested on `Ubuntu 20.04` and `Python 3.9`.
13
+ ```bash
14
+ git clone https://github.com/myshell-ai/MeloTTS.git
15
+ cd MeloTTS
16
+ pip install -e .
17
+ python -m unidic download
18
+ ```
19
+ If you encounter issues with the macOS install, try the [Docker Install](#docker-install).
20
+
21
+ ### Docker Install
22
+ To avoid compatibility issues, we suggest that Windows users and some macOS users run MeloTTS via Docker. Ensure that [you have Docker installed](https://docs.docker.com/engine/install/).
23
+
24
+ **Build Docker**
25
+
26
+ This could take a few minutes.
27
+ ```bash
28
+ git clone https://github.com/myshell-ai/MeloTTS.git
29
+ cd MeloTTS
30
+ docker build -t melotts .
31
+ ```
32
+
33
+ **Run Docker**
34
+ ```bash
35
+ docker run -it -p 8888:8888 melotts
36
+ ```
37
+ If your local machine has a GPU, you can instead run:
38
+ ```bash
39
+ docker run --gpus all -it -p 8888:8888 melotts
40
+ ```
41
+ Then open [http://localhost:8888](http://localhost:8888) in your browser to use the app.
42
+
43
+ ## Usage
44
+
45
+ ### WebUI
46
+
47
+ The WebUI supports multiple languages and voices. First, follow the installation steps. Then, simply run:
48
+
49
+ ```bash
50
+ melo-ui
51
+ # Or: python melo/app.py
52
+ ```
53
+
54
+ ### CLI
55
+
56
+ You can use the MeloTTS CLI to interact with MeloTTS; it can be invoked as either `melotts` or `melo`. Here are some examples:
57
+
58
+ **Read English text:**
59
+
60
+ ```bash
61
+ melo "Text to read" output.wav
62
+ ```
63
+
64
+ **Specify a language:**
65
+
66
+ ```bash
67
+ melo "Text to read" output.wav --language EN
68
+ ```
69
+
70
+ **Specify a speaker:**
71
+
72
+ ```bash
73
+ melo "Text to read" output.wav --language EN --speaker EN-US
74
+ melo "Text to read" output.wav --language EN --speaker EN-AU
75
+ ```
76
+
77
+ The available speakers are `EN-Default`, `EN-US`, `EN-BR`, `EN_INDIA`, and `EN-AU`.
78
+
79
+ **Specify a speed:**
80
+
81
+ ```bash
82
+ melo "Text to read" output.wav --language EN --speaker EN-US --speed 1.5
83
+ melo "Text to read" output.wav --speed 1.5
84
+ ```
85
+
86
+ **Use a different language:**
87
+
88
+ ```bash
89
+ melo "text-to-speech 领域近年来发展迅速" zh.wav -l ZH
90
+ ```
91
+
92
+ **Load from a file:**
93
+
94
+ ```bash
95
+ melo file.txt out.wav --file
96
+ ```
97
+
98
+ The full set of CLI options can be viewed with:
99
+
100
+ ```bash
101
+ melo --help
102
+ ```
103
+
104
+ ### Python API
105
+
106
+ #### English with Multiple Accents
107
+
108
+ ```python
109
+ from melo.api import TTS
110
+
111
+ # Speed is adjustable
112
+ speed = 1.0
113
+
114
+ # CPU is sufficient for real-time inference.
115
+ # You can set it manually to 'cpu' or 'cuda' or 'cuda:0' or 'mps'
116
+ device = 'auto' # Will automatically use GPU if available
117
+
118
+ # English
119
+ text = "Did you ever hear a folk tale about a giant turtle?"
120
+ model = TTS(language='EN', device=device)
121
+ speaker_ids = model.hps.data.spk2id
122
+
123
+ # American accent
124
+ output_path = 'en-us.wav'
125
+ model.tts_to_file(text, speaker_ids['EN-US'], output_path, speed=speed)
126
+
127
+ # British accent
128
+ output_path = 'en-br.wav'
129
+ model.tts_to_file(text, speaker_ids['EN-BR'], output_path, speed=speed)
130
+
131
+ # Indian accent
132
+ output_path = 'en-india.wav'
133
+ model.tts_to_file(text, speaker_ids['EN_INDIA'], output_path, speed=speed)
134
+
135
+ # Australian accent
136
+ output_path = 'en-au.wav'
137
+ model.tts_to_file(text, speaker_ids['EN-AU'], output_path, speed=speed)
138
+
139
+ # Default accent
140
+ output_path = 'en-default.wav'
141
+ model.tts_to_file(text, speaker_ids['EN-Default'], output_path, speed=speed)
142
+
143
+ ```
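+
+ If you need the raw waveform instead of a file: `tts_to_file` returns the synthesized audio as a NumPy array when `output_path` is `None` (see `melo/api.py`). A minimal sketch:
+
+ ```python
+ from melo.api import TTS
+
+ model = TTS(language='EN', device='auto')
+ speaker_ids = model.hps.data.spk2id
+
+ # With output_path=None the float32 waveform is returned instead of being written to disk
+ audio = model.tts_to_file("Did you ever hear a folk tale about a giant turtle?",
+                           speaker_ids['EN-US'], output_path=None, speed=1.0)
+ print(audio.shape, model.hps.data.sampling_rate)
+ ```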
144
+
145
+ #### Spanish
146
+ ```python
147
+ from melo.api import TTS
148
+
149
+ # Speed is adjustable
150
+ speed = 1.0
151
+
152
+ # CPU is sufficient for real-time inference.
153
+ # You can also change to cuda:0
154
+ device = 'cpu'
155
+
156
+ text = "El resplandor del sol acaricia las olas, pintando el cielo con una paleta deslumbrante."
157
+ model = TTS(language='ES', device=device)
158
+ speaker_ids = model.hps.data.spk2id
159
+
160
+ output_path = 'es.wav'
161
+ model.tts_to_file(text, speaker_ids['ES'], output_path, speed=speed)
162
+ ```
163
+
164
+ #### French
165
+
166
+ ```python
167
+ from melo.api import TTS
168
+
169
+ # Speed is adjustable
170
+ speed = 1.0
171
+ device = 'cpu' # or cuda:0
172
+
173
+ text = "La lueur dorée du soleil caresse les vagues, peignant le ciel d'une palette éblouissante."
174
+ model = TTS(language='FR', device=device)
175
+ speaker_ids = model.hps.data.spk2id
176
+
177
+ output_path = 'fr.wav'
178
+ model.tts_to_file(text, speaker_ids['FR'], output_path, speed=speed)
179
+ ```
180
+
181
+ #### Chinese
182
+
183
+ ```python
184
+ from melo.api import TTS
185
+
186
+ # Speed is adjustable
187
+ speed = 1.0
188
+ device = 'cpu' # or cuda:0
189
+
190
+ text = "我最近在学习machine learning,希望能够在未来的artificial intelligence领域有所建树。"
191
+ model = TTS(language='ZH', device=device)
192
+ speaker_ids = model.hps.data.spk2id
193
+
194
+ output_path = 'zh.wav'
195
+ model.tts_to_file(text, speaker_ids['ZH'], output_path, speed=speed)
196
+ ```
197
+
198
+ #### Japanese
199
+
200
+ ```python
201
+ from melo.api import TTS
202
+
203
+ # Speed is adjustable
204
+ speed = 1.0
205
+ device = 'cpu' # or cuda:0
206
+
207
+ text = "彼は毎朝ジョギングをして体を健康に保っています。"
208
+ model = TTS(language='JP', device=device)
209
+ speaker_ids = model.hps.data.spk2id
210
+
211
+ output_path = 'jp.wav'
212
+ model.tts_to_file(text, speaker_ids['JP'], output_path, speed=speed)
213
+ ```
214
+
215
+ #### Korean
216
+
217
+ ```python
218
+ from melo.api import TTS
219
+
220
+ # Speed is adjustable
221
+ speed = 1.0
222
+ device = 'cpu' # or cuda:0
223
+
224
+ text = "안녕하세요! 오늘은 날씨가 정말 좋네요."
225
+ model = TTS(language='KR', device=device)
226
+ speaker_ids = model.hps.data.spk2id
227
+
228
+ output_path = 'kr.wav'
229
+ model.tts_to_file(text, speaker_ids['KR'], output_path, speed=speed)
230
+ ```
MeloTTS/docs/quick_use.md ADDED
@@ -0,0 +1,45 @@
1
+ ## Use MeloTTS without Installation
2
+
3
+ **Quick Demo**
4
+
5
+ - [Official live demo](https://app.myshell.ai/bot/UN77N3/1709094629) on MyShell.
6
+ - Hugging Face Space [live demo](https://huggingface.co/spaces/mrfakename/MeloTTS).
7
+
8
+ **Use on MyShell**
9
+
10
+ There are hundreds of TTS voices on MyShell, many more than MeloTTS alone provides. For example:
11
+
12
+ English
13
+ - [gentle British male voice](https://app.myshell.ai/widget/nIfamm)
14
+ - [cheerful young female voice](https://app.myshell.ai/widget/AjIjqy)
15
+ - [sultry and robust male voice](https://app.myshell.ai/widget/zQJJN3)
16
+
17
+ Spanish
18
+ - [voz femenina adorable](https://app.myshell.ai/widget/buIZBf)
19
+ - [voz masculina joven](https://app.myshell.ai/widget/rayuiy)
20
+ - [voz de niña inmadura](https://app.myshell.ai/widget/mYFV3e)
21
+
22
+ French
23
+ - [voix adorable de fille](https://app.myshell.ai/widget/3IfEfy)
24
+ - [voix douce masculine](https://app.myshell.ai/widget/IRR3M3)
25
+ - [voix douce féminine](https://app.myshell.ai/widget/NRbaUj)
26
+
27
+ German
28
+ - [sanfte Männerstimme](https://app.myshell.ai/widget/JFnAn2)
29
+ - [sanfte Frauenstimme](https://app.myshell.ai/widget/MrU7Nb)
30
+ - [unreife Mädchenstimme](https://app.myshell.ai/widget/UFbYBj)
31
+
32
+ Portuguese
33
+ - [voz feminina nítida](https://app.myshell.ai/widget/VzMb6j)
34
+ - [voz de menino imaturo](https://app.myshell.ai/widget/nAzeei)
35
+ - [voz masculina sóbria](https://app.myshell.ai/widget/JZRNJz)
36
+
37
+ Russian
38
+ - [зрелый женский голос](https://app.myshell.ai/widget/6byMZ3)
39
+ - [зрелый мужской голос](https://app.myshell.ai/widget/NB7jmm)
40
+
41
+ Chinese
42
+ - [甜美女声](https://app.myshell.ai/widget/ymeUjm)
43
+ - [青年男声](https://app.myshell.ai/widget/NZnERb)
44
+
45
+ More can be found at the widget center of [MyShell.ai](https://app.myshell.ai/robot-workshop).
MeloTTS/docs/training.md ADDED
@@ -0,0 +1,37 @@
1
+ ## Training
2
+
3
+ Before training, please install MeloTTS in dev mode and go to the `melo` folder.
4
+ ```
5
+ pip install -e .
6
+ cd melo
7
+ ```
8
+
9
+ ### Data Preparation
10
+ To train a TTS model, we need to prepare the audio files and a metadata file. We recommend using 44100 Hz audio files; the metadata file should have the following format:
11
+
12
+ ```
13
+ path/to/audio_001.wav|<speaker_name>|<language_code>|<text_001>
14
+ path/to/audio_002.wav|<speaker_name>|<language_code>|<text_002>
15
+ ```
16
+ The transcribed text can be obtained with an ASR model (e.g., [whisper](https://github.com/openai/whisper)). An example metadata file can be found in `data/example/metadata.list`.
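+
+ For illustration, a rough sketch of generating such a metadata file with the openai-whisper package is shown below; the dataset paths and speaker name are placeholders to adapt to your data:
+
+ ```python
+ import os
+ import whisper  # pip install openai-whisper
+
+ asr = whisper.load_model("medium")
+ speaker, language = "MySpeaker", "EN"
+ wav_dir = "data/my_dataset/wavs"
+
+ with open("data/my_dataset/metadata.list", "w", encoding="utf-8") as f:
+     for name in sorted(os.listdir(wav_dir)):
+         if not name.endswith(".wav"):
+             continue
+         path = os.path.join(wav_dir, name)
+         text = asr.transcribe(path)["text"].strip()
+         # One line per clip: <wav_path>|<speaker_name>|<language_code>|<text>
+         f.write(f"{path}|{speaker}|{language}|{text}\n")
+ ```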
17
+
18
+ We can then run the preprocessing code:
19
+ ```
20
+ python preprocess_text.py --metadata data/example/metadata.list
21
+ ```
22
+ A config file `data/example/config.json` will be generated. Feel free to edit the hyper-parameters in that config file (for example, you may decrease the batch size if you encounter CUDA out-of-memory errors).
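+
+ For example, to lower the batch size you can edit the generated config by hand, or patch it with a short script (the field layout follows `melo/configs/config.json` in this repo):
+
+ ```python
+ import json
+
+ with open("data/example/config.json") as f:
+     config = json.load(f)
+
+ config["train"]["batch_size"] = 3  # the default in configs/config.json is 6
+
+ with open("data/example/config.json", "w") as f:
+     json.dump(config, f, indent=2)
+ ```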
23
+
24
+ ### Training
25
+ The training can be launched by:
26
+ ```
27
+ bash train.sh <path/to/config.json> <num_of_gpus>
28
+ ```
29
+
30
+ We have found that on some machines training will sometimes crash due to a gloo [issue](https://github.com/pytorch/pytorch/issues/2530). Therefore, we added an auto-resume wrapper in `train.sh`.
31
+
32
+ ### Inference
33
+ Simply run:
34
+ ```
35
+ python infer.py --text "<some text here>" -m /path/to/checkpoint/G_<iter>.pth -o <output_dir>
36
+ ```
37
+
MeloTTS/logo.png ADDED

Git LFS Details

  • SHA256: 0f0625ee2514fde1d65b2bd29ba11f67304ba4174f8295624b34eac967078fd6
  • Pointer size: 131 Bytes
  • Size of remote file: 186 kB
MeloTTS/melo/__init__.py ADDED
File without changes
MeloTTS/melo/api.py ADDED
@@ -0,0 +1,135 @@
1
+ import os
2
+ import re
3
+ import json
4
+ import torch
5
+ import librosa
6
+ import soundfile
7
+ import torchaudio
8
+ import numpy as np
9
+ import torch.nn as nn
10
+ from tqdm import tqdm
11
+ import torch
12
+
13
+ from . import utils
14
+ from . import commons
15
+ from .models import SynthesizerTrn
16
+ from .split_utils import split_sentence
17
+ from .mel_processing import spectrogram_torch, spectrogram_torch_conv
18
+ from .download_utils import load_or_download_config, load_or_download_model
19
+
20
+ class TTS(nn.Module):
21
+ def __init__(self,
22
+ language,
23
+ device='auto',
24
+ use_hf=True,
25
+ config_path=None,
26
+ ckpt_path=None):
27
+ super().__init__()
28
+ if device == 'auto':
29
+ device = 'cpu'
30
+ if torch.cuda.is_available(): device = 'cuda'
31
+ if torch.backends.mps.is_available(): device = 'mps'
32
+ if 'cuda' in device:
33
+ assert torch.cuda.is_available()
34
+
35
+ # config_path =
36
+ hps = load_or_download_config(language, use_hf=use_hf, config_path=config_path)
37
+
38
+ num_languages = hps.num_languages
39
+ num_tones = hps.num_tones
40
+ symbols = hps.symbols
41
+
42
+ model = SynthesizerTrn(
43
+ len(symbols),
44
+ hps.data.filter_length // 2 + 1,
45
+ hps.train.segment_size // hps.data.hop_length,
46
+ n_speakers=hps.data.n_speakers,
47
+ num_tones=num_tones,
48
+ num_languages=num_languages,
49
+ **hps.model,
50
+ ).to(device)
51
+
52
+ model.eval()
53
+ self.model = model
54
+ self.symbol_to_id = {s: i for i, s in enumerate(symbols)}
55
+ self.hps = hps
56
+ self.device = device
57
+
58
+ # load state_dict
59
+ checkpoint_dict = load_or_download_model(language, device, use_hf=use_hf, ckpt_path=ckpt_path)
60
+ self.model.load_state_dict(checkpoint_dict['model'], strict=True)
61
+
62
+ language = language.split('_')[0]
63
+ self.language = 'ZH_MIX_EN' if language == 'ZH' else language # we support a ZH_MIX_EN model
64
+
65
+ @staticmethod
66
+ def audio_numpy_concat(segment_data_list, sr, speed=1.):
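+ # Flatten each per-sentence segment and insert ~50 ms of silence (scaled by 1/speed) between segments before concatenating.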
67
+ audio_segments = []
68
+ for segment_data in segment_data_list:
69
+ audio_segments += segment_data.reshape(-1).tolist()
70
+ audio_segments += [0] * int((sr * 0.05) / speed)
71
+ audio_segments = np.array(audio_segments).astype(np.float32)
72
+ return audio_segments
73
+
74
+ @staticmethod
75
+ def split_sentences_into_pieces(text, language, quiet=False):
76
+ texts = split_sentence(text, language_str=language)
77
+ if not quiet:
78
+ print(" > Text split to sentences.")
79
+ print('\n'.join(texts))
80
+ print(" > ===========================")
81
+ return texts
82
+
83
+ def tts_to_file(self, text, speaker_id, output_path=None, sdp_ratio=0.2, noise_scale=0.6, noise_scale_w=0.8, speed=1.0, pbar=None, format=None, position=None, quiet=False,):
84
+ language = self.language
85
+ texts = self.split_sentences_into_pieces(text, language, quiet)
86
+ audio_list = []
87
+ if pbar:
88
+ tx = pbar(texts)
89
+ else:
90
+ if position:
91
+ tx = tqdm(texts, position=position)
92
+ elif quiet:
93
+ tx = texts
94
+ else:
95
+ tx = tqdm(texts)
96
+ for t in tx:
97
+ if language in ['EN', 'ZH_MIX_EN']:
98
+ t = re.sub(r'([a-z])([A-Z])', r'\1 \2', t)
99
+ device = self.device
100
+ bert, ja_bert, phones, tones, lang_ids = utils.get_text_for_tts_infer(t, language, self.hps, device, self.symbol_to_id)
101
+ with torch.no_grad():
102
+ x_tst = phones.to(device).unsqueeze(0)
103
+ tones = tones.to(device).unsqueeze(0)
104
+ lang_ids = lang_ids.to(device).unsqueeze(0)
105
+ bert = bert.to(device).unsqueeze(0)
106
+ ja_bert = ja_bert.to(device).unsqueeze(0)
107
+ x_tst_lengths = torch.LongTensor([phones.size(0)]).to(device)
108
+ del phones
109
+ speakers = torch.LongTensor([speaker_id]).to(device)
110
+ audio = self.model.infer(
111
+ x_tst,
112
+ x_tst_lengths,
113
+ speakers,
114
+ tones,
115
+ lang_ids,
116
+ bert,
117
+ ja_bert,
118
+ sdp_ratio=sdp_ratio,
119
+ noise_scale=noise_scale,
120
+ noise_scale_w=noise_scale_w,
121
+ length_scale=1. / speed,
122
+ )[0][0, 0].data.cpu().float().numpy()
123
+ del x_tst, tones, lang_ids, bert, ja_bert, x_tst_lengths, speakers
124
+ #
125
+ audio_list.append(audio)
126
+ torch.cuda.empty_cache()
127
+ audio = self.audio_numpy_concat(audio_list, sr=self.hps.data.sampling_rate, speed=speed)
128
+
129
+ if output_path is None:
130
+ return audio
131
+ else:
132
+ if format:
133
+ soundfile.write(output_path, audio, self.hps.data.sampling_rate, format=format)
134
+ else:
135
+ soundfile.write(output_path, audio, self.hps.data.sampling_rate)
MeloTTS/melo/app.py ADDED
@@ -0,0 +1,61 @@
1
+ # WebUI by mrfakename <X @realmrfakename / HF @mrfakename>
2
+ # Demo also available on HF Spaces: https://huggingface.co/spaces/mrfakename/MeloTTS
3
+ import gradio as gr
4
+ import os, torch, io
5
+ # os.system('python -m unidic download')
6
+ print("Make sure you've downloaded unidic (python -m unidic download) for this WebUI to work.")
7
+ from melo.api import TTS
8
+ speed = 1.0
9
+ import tempfile
10
+ import click
11
+ device = 'auto'
12
+ models = {
13
+ 'EN': TTS(language='EN', device=device),
14
+ 'ES': TTS(language='ES', device=device),
15
+ 'FR': TTS(language='FR', device=device),
16
+ 'ZH': TTS(language='ZH', device=device),
17
+ 'JP': TTS(language='JP', device=device),
18
+ 'KR': TTS(language='KR', device=device),
19
+ }
20
+ speaker_ids = models['EN'].hps.data.spk2id
21
+
22
+ default_text_dict = {
23
+ 'EN': 'The field of text-to-speech has seen rapid development recently.',
24
+ 'ES': 'El campo de la conversión de texto a voz ha experimentado un rápido desarrollo recientemente.',
25
+ 'FR': 'Le domaine de la synthèse vocale a connu un développement rapide récemment',
26
+ 'ZH': 'text-to-speech 领域近年来发展迅速',
27
+ 'JP': 'テキスト読み上げの分野は最近急速な発展を遂げています',
28
+ 'KR': '최근 텍스트 음성 변환 분야가 급속도로 발전하고 있습니다.',
29
+ }
30
+
31
+ def synthesize(speaker, text, speed, language, progress=gr.Progress()):
32
+ bio = io.BytesIO()
33
+ models[language].tts_to_file(text, models[language].hps.data.spk2id[speaker], bio, speed=speed, pbar=progress.tqdm, format='wav')
34
+ return bio.getvalue()
35
+ def load_speakers(language, text):
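+ # If the textbox still contains one of the default sentences, switch it to the newly selected language's default; then refresh the speaker choices.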
36
+ if text in list(default_text_dict.values()):
37
+ newtext = default_text_dict[language]
38
+ else:
39
+ newtext = text
40
+ return gr.update(value=list(models[language].hps.data.spk2id.keys())[0], choices=list(models[language].hps.data.spk2id.keys())), newtext
41
+ with gr.Blocks() as demo:
42
+ gr.Markdown('# MeloTTS WebUI\n\nA WebUI for MeloTTS.')
43
+ with gr.Group():
44
+ speaker = gr.Dropdown(speaker_ids.keys(), interactive=True, value='EN-US', label='Speaker')
45
+ language = gr.Radio(['EN', 'ES', 'FR', 'ZH', 'JP', 'KR'], label='Language', value='EN')
46
+ speed = gr.Slider(label='Speed', minimum=0.1, maximum=10.0, value=1.0, interactive=True, step=0.1)
47
+ text = gr.Textbox(label="Text to speak", value=default_text_dict['EN'])
48
+ language.input(load_speakers, inputs=[language, text], outputs=[speaker, text])
49
+ btn = gr.Button('Synthesize', variant='primary')
50
+ aud = gr.Audio(interactive=False)
51
+ btn.click(synthesize, inputs=[speaker, text, speed, language], outputs=[aud])
52
+ gr.Markdown('WebUI by [mrfakename](https://twitter.com/realmrfakename).')
53
+ @click.command()
54
+ @click.option('--share', '-s', is_flag=True, show_default=True, default=False, help="Expose a publicly-accessible shared Gradio link usable by anyone with the link. Only share the link with people you trust.")
55
+ @click.option('--host', '-h', default=None)
56
+ @click.option('--port', '-p', type=int, default=None)
57
+ def main(share, host, port):
58
+ demo.queue(api_open=False).launch(show_api=False, share=share, server_name=host, server_port=port)
59
+
60
+ if __name__ == "__main__":
61
+ main()
MeloTTS/melo/attentions.py ADDED
@@ -0,0 +1,459 @@
1
+ import math
2
+ import torch
3
+ from torch import nn
4
+ from torch.nn import functional as F
5
+
6
+ from . import commons
7
+ import logging
8
+
9
+ logger = logging.getLogger(__name__)
10
+
11
+
12
+ class LayerNorm(nn.Module):
13
+ def __init__(self, channels, eps=1e-5):
14
+ super().__init__()
15
+ self.channels = channels
16
+ self.eps = eps
17
+
18
+ self.gamma = nn.Parameter(torch.ones(channels))
19
+ self.beta = nn.Parameter(torch.zeros(channels))
20
+
21
+ def forward(self, x):
22
+ x = x.transpose(1, -1)
23
+ x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps)
24
+ return x.transpose(1, -1)
25
+
26
+
27
+ @torch.jit.script
28
+ def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
29
+ n_channels_int = n_channels[0]
30
+ in_act = input_a + input_b
31
+ t_act = torch.tanh(in_act[:, :n_channels_int, :])
32
+ s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
33
+ acts = t_act * s_act
34
+ return acts
35
+
36
+
37
+ class Encoder(nn.Module):
38
+ def __init__(
39
+ self,
40
+ hidden_channels,
41
+ filter_channels,
42
+ n_heads,
43
+ n_layers,
44
+ kernel_size=1,
45
+ p_dropout=0.0,
46
+ window_size=4,
47
+ isflow=True,
48
+ **kwargs
49
+ ):
50
+ super().__init__()
51
+ self.hidden_channels = hidden_channels
52
+ self.filter_channels = filter_channels
53
+ self.n_heads = n_heads
54
+ self.n_layers = n_layers
55
+ self.kernel_size = kernel_size
56
+ self.p_dropout = p_dropout
57
+ self.window_size = window_size
58
+
59
+ self.cond_layer_idx = self.n_layers
60
+ if "gin_channels" in kwargs:
61
+ self.gin_channels = kwargs["gin_channels"]
62
+ if self.gin_channels != 0:
63
+ self.spk_emb_linear = nn.Linear(self.gin_channels, self.hidden_channels)
64
+ self.cond_layer_idx = (
65
+ kwargs["cond_layer_idx"] if "cond_layer_idx" in kwargs else 2
66
+ )
67
+ assert (
68
+ self.cond_layer_idx < self.n_layers
69
+ ), "cond_layer_idx should be less than n_layers"
70
+ self.drop = nn.Dropout(p_dropout)
71
+ self.attn_layers = nn.ModuleList()
72
+ self.norm_layers_1 = nn.ModuleList()
73
+ self.ffn_layers = nn.ModuleList()
74
+ self.norm_layers_2 = nn.ModuleList()
75
+
76
+ for i in range(self.n_layers):
77
+ self.attn_layers.append(
78
+ MultiHeadAttention(
79
+ hidden_channels,
80
+ hidden_channels,
81
+ n_heads,
82
+ p_dropout=p_dropout,
83
+ window_size=window_size,
84
+ )
85
+ )
86
+ self.norm_layers_1.append(LayerNorm(hidden_channels))
87
+ self.ffn_layers.append(
88
+ FFN(
89
+ hidden_channels,
90
+ hidden_channels,
91
+ filter_channels,
92
+ kernel_size,
93
+ p_dropout=p_dropout,
94
+ )
95
+ )
96
+ self.norm_layers_2.append(LayerNorm(hidden_channels))
97
+
98
+ def forward(self, x, x_mask, g=None):
99
+ attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
100
+ x = x * x_mask
101
+ for i in range(self.n_layers):
102
+ if i == self.cond_layer_idx and g is not None:
103
+ g = self.spk_emb_linear(g.transpose(1, 2))
104
+ g = g.transpose(1, 2)
105
+ x = x + g
106
+ x = x * x_mask
107
+ y = self.attn_layers[i](x, x, attn_mask)
108
+ y = self.drop(y)
109
+ x = self.norm_layers_1[i](x + y)
110
+
111
+ y = self.ffn_layers[i](x, x_mask)
112
+ y = self.drop(y)
113
+ x = self.norm_layers_2[i](x + y)
114
+ x = x * x_mask
115
+ return x
116
+
117
+
118
+ class Decoder(nn.Module):
119
+ def __init__(
120
+ self,
121
+ hidden_channels,
122
+ filter_channels,
123
+ n_heads,
124
+ n_layers,
125
+ kernel_size=1,
126
+ p_dropout=0.0,
127
+ proximal_bias=False,
128
+ proximal_init=True,
129
+ **kwargs
130
+ ):
131
+ super().__init__()
132
+ self.hidden_channels = hidden_channels
133
+ self.filter_channels = filter_channels
134
+ self.n_heads = n_heads
135
+ self.n_layers = n_layers
136
+ self.kernel_size = kernel_size
137
+ self.p_dropout = p_dropout
138
+ self.proximal_bias = proximal_bias
139
+ self.proximal_init = proximal_init
140
+
141
+ self.drop = nn.Dropout(p_dropout)
142
+ self.self_attn_layers = nn.ModuleList()
143
+ self.norm_layers_0 = nn.ModuleList()
144
+ self.encdec_attn_layers = nn.ModuleList()
145
+ self.norm_layers_1 = nn.ModuleList()
146
+ self.ffn_layers = nn.ModuleList()
147
+ self.norm_layers_2 = nn.ModuleList()
148
+ for i in range(self.n_layers):
149
+ self.self_attn_layers.append(
150
+ MultiHeadAttention(
151
+ hidden_channels,
152
+ hidden_channels,
153
+ n_heads,
154
+ p_dropout=p_dropout,
155
+ proximal_bias=proximal_bias,
156
+ proximal_init=proximal_init,
157
+ )
158
+ )
159
+ self.norm_layers_0.append(LayerNorm(hidden_channels))
160
+ self.encdec_attn_layers.append(
161
+ MultiHeadAttention(
162
+ hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout
163
+ )
164
+ )
165
+ self.norm_layers_1.append(LayerNorm(hidden_channels))
166
+ self.ffn_layers.append(
167
+ FFN(
168
+ hidden_channels,
169
+ hidden_channels,
170
+ filter_channels,
171
+ kernel_size,
172
+ p_dropout=p_dropout,
173
+ causal=True,
174
+ )
175
+ )
176
+ self.norm_layers_2.append(LayerNorm(hidden_channels))
177
+
178
+ def forward(self, x, x_mask, h, h_mask):
179
+ """
180
+ x: decoder input
181
+ h: encoder output
182
+ """
183
+ self_attn_mask = commons.subsequent_mask(x_mask.size(2)).to(
184
+ device=x.device, dtype=x.dtype
185
+ )
186
+ encdec_attn_mask = h_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
187
+ x = x * x_mask
188
+ for i in range(self.n_layers):
189
+ y = self.self_attn_layers[i](x, x, self_attn_mask)
190
+ y = self.drop(y)
191
+ x = self.norm_layers_0[i](x + y)
192
+
193
+ y = self.encdec_attn_layers[i](x, h, encdec_attn_mask)
194
+ y = self.drop(y)
195
+ x = self.norm_layers_1[i](x + y)
196
+
197
+ y = self.ffn_layers[i](x, x_mask)
198
+ y = self.drop(y)
199
+ x = self.norm_layers_2[i](x + y)
200
+ x = x * x_mask
201
+ return x
202
+
203
+
204
+ class MultiHeadAttention(nn.Module):
205
+ def __init__(
206
+ self,
207
+ channels,
208
+ out_channels,
209
+ n_heads,
210
+ p_dropout=0.0,
211
+ window_size=None,
212
+ heads_share=True,
213
+ block_length=None,
214
+ proximal_bias=False,
215
+ proximal_init=False,
216
+ ):
217
+ super().__init__()
218
+ assert channels % n_heads == 0
219
+
220
+ self.channels = channels
221
+ self.out_channels = out_channels
222
+ self.n_heads = n_heads
223
+ self.p_dropout = p_dropout
224
+ self.window_size = window_size
225
+ self.heads_share = heads_share
226
+ self.block_length = block_length
227
+ self.proximal_bias = proximal_bias
228
+ self.proximal_init = proximal_init
229
+ self.attn = None
230
+
231
+ self.k_channels = channels // n_heads
232
+ self.conv_q = nn.Conv1d(channels, channels, 1)
233
+ self.conv_k = nn.Conv1d(channels, channels, 1)
234
+ self.conv_v = nn.Conv1d(channels, channels, 1)
235
+ self.conv_o = nn.Conv1d(channels, out_channels, 1)
236
+ self.drop = nn.Dropout(p_dropout)
237
+
238
+ if window_size is not None:
239
+ n_heads_rel = 1 if heads_share else n_heads
240
+ rel_stddev = self.k_channels**-0.5
241
+ self.emb_rel_k = nn.Parameter(
242
+ torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
243
+ * rel_stddev
244
+ )
245
+ self.emb_rel_v = nn.Parameter(
246
+ torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
247
+ * rel_stddev
248
+ )
249
+
250
+ nn.init.xavier_uniform_(self.conv_q.weight)
251
+ nn.init.xavier_uniform_(self.conv_k.weight)
252
+ nn.init.xavier_uniform_(self.conv_v.weight)
253
+ if proximal_init:
254
+ with torch.no_grad():
255
+ self.conv_k.weight.copy_(self.conv_q.weight)
256
+ self.conv_k.bias.copy_(self.conv_q.bias)
257
+
258
+ def forward(self, x, c, attn_mask=None):
259
+ q = self.conv_q(x)
260
+ k = self.conv_k(c)
261
+ v = self.conv_v(c)
262
+
263
+ x, self.attn = self.attention(q, k, v, mask=attn_mask)
264
+
265
+ x = self.conv_o(x)
266
+ return x
267
+
268
+ def attention(self, query, key, value, mask=None):
269
+ # reshape [b, d, t] -> [b, n_h, t, d_k]
270
+ b, d, t_s, t_t = (*key.size(), query.size(2))
271
+ query = query.view(b, self.n_heads, self.k_channels, t_t).transpose(2, 3)
272
+ key = key.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
273
+ value = value.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
274
+
275
+ scores = torch.matmul(query / math.sqrt(self.k_channels), key.transpose(-2, -1))
276
+ if self.window_size is not None:
277
+ assert (
278
+ t_s == t_t
279
+ ), "Relative attention is only available for self-attention."
280
+ key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s)
281
+ rel_logits = self._matmul_with_relative_keys(
282
+ query / math.sqrt(self.k_channels), key_relative_embeddings
283
+ )
284
+ scores_local = self._relative_position_to_absolute_position(rel_logits)
285
+ scores = scores + scores_local
286
+ if self.proximal_bias:
287
+ assert t_s == t_t, "Proximal bias is only available for self-attention."
288
+ scores = scores + self._attention_bias_proximal(t_s).to(
289
+ device=scores.device, dtype=scores.dtype
290
+ )
291
+ if mask is not None:
292
+ scores = scores.masked_fill(mask == 0, -1e4)
293
+ if self.block_length is not None:
294
+ assert (
295
+ t_s == t_t
296
+ ), "Local attention is only available for self-attention."
297
+ block_mask = (
298
+ torch.ones_like(scores)
299
+ .triu(-self.block_length)
300
+ .tril(self.block_length)
301
+ )
302
+ scores = scores.masked_fill(block_mask == 0, -1e4)
303
+ p_attn = F.softmax(scores, dim=-1) # [b, n_h, t_t, t_s]
304
+ p_attn = self.drop(p_attn)
305
+ output = torch.matmul(p_attn, value)
306
+ if self.window_size is not None:
307
+ relative_weights = self._absolute_position_to_relative_position(p_attn)
308
+ value_relative_embeddings = self._get_relative_embeddings(
309
+ self.emb_rel_v, t_s
310
+ )
311
+ output = output + self._matmul_with_relative_values(
312
+ relative_weights, value_relative_embeddings
313
+ )
314
+ output = (
315
+ output.transpose(2, 3).contiguous().view(b, d, t_t)
316
+ ) # [b, n_h, t_t, d_k] -> [b, d, t_t]
317
+ return output, p_attn
318
+
319
+ def _matmul_with_relative_values(self, x, y):
320
+ """
321
+ x: [b, h, l, m]
322
+ y: [h or 1, m, d]
323
+ ret: [b, h, l, d]
324
+ """
325
+ ret = torch.matmul(x, y.unsqueeze(0))
326
+ return ret
327
+
328
+ def _matmul_with_relative_keys(self, x, y):
329
+ """
330
+ x: [b, h, l, d]
331
+ y: [h or 1, m, d]
332
+ ret: [b, h, l, m]
333
+ """
334
+ ret = torch.matmul(x, y.unsqueeze(0).transpose(-2, -1))
335
+ return ret
336
+
337
+ def _get_relative_embeddings(self, relative_embeddings, length):
338
+ 2 * self.window_size + 1
339
+ # Pad first before slice to avoid using cond ops.
340
+ pad_length = max(length - (self.window_size + 1), 0)
341
+ slice_start_position = max((self.window_size + 1) - length, 0)
342
+ slice_end_position = slice_start_position + 2 * length - 1
343
+ if pad_length > 0:
344
+ padded_relative_embeddings = F.pad(
345
+ relative_embeddings,
346
+ commons.convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]]),
347
+ )
348
+ else:
349
+ padded_relative_embeddings = relative_embeddings
350
+ used_relative_embeddings = padded_relative_embeddings[
351
+ :, slice_start_position:slice_end_position
352
+ ]
353
+ return used_relative_embeddings
354
+
355
+ def _relative_position_to_absolute_position(self, x):
356
+ """
357
+ x: [b, h, l, 2*l-1]
358
+ ret: [b, h, l, l]
359
+ """
360
+ batch, heads, length, _ = x.size()
361
+ # Concat columns of pad to shift from relative to absolute indexing.
362
+ x = F.pad(x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, 1]]))
363
+
364
+ # Concat extra elements so to add up to shape (len+1, 2*len-1).
365
+ x_flat = x.view([batch, heads, length * 2 * length])
366
+ x_flat = F.pad(
367
+ x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [0, length - 1]])
368
+ )
369
+
370
+ # Reshape and slice out the padded elements.
371
+ x_final = x_flat.view([batch, heads, length + 1, 2 * length - 1])[
372
+ :, :, :length, length - 1 :
373
+ ]
374
+ return x_final
375
+
376
+ def _absolute_position_to_relative_position(self, x):
377
+ """
378
+ x: [b, h, l, l]
379
+ ret: [b, h, l, 2*l-1]
380
+ """
381
+ batch, heads, length, _ = x.size()
382
+ # pad along column
383
+ x = F.pad(
384
+ x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length - 1]])
385
+ )
386
+ x_flat = x.view([batch, heads, length**2 + length * (length - 1)])
387
+ # add 0's in the beginning that will skew the elements after reshape
388
+ x_flat = F.pad(x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [length, 0]]))
389
+ x_final = x_flat.view([batch, heads, length, 2 * length])[:, :, :, 1:]
390
+ return x_final
391
+
392
+ def _attention_bias_proximal(self, length):
393
+ """Bias for self-attention to encourage attention to close positions.
394
+ Args:
395
+ length: an integer scalar.
396
+ Returns:
397
+ a Tensor with shape [1, 1, length, length]
398
+ """
399
+ r = torch.arange(length, dtype=torch.float32)
400
+ diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1)
401
+ return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0)
402
+
403
+
404
+ class FFN(nn.Module):
405
+ def __init__(
406
+ self,
407
+ in_channels,
408
+ out_channels,
409
+ filter_channels,
410
+ kernel_size,
411
+ p_dropout=0.0,
412
+ activation=None,
413
+ causal=False,
414
+ ):
415
+ super().__init__()
416
+ self.in_channels = in_channels
417
+ self.out_channels = out_channels
418
+ self.filter_channels = filter_channels
419
+ self.kernel_size = kernel_size
420
+ self.p_dropout = p_dropout
421
+ self.activation = activation
422
+ self.causal = causal
423
+
424
+ if causal:
425
+ self.padding = self._causal_padding
426
+ else:
427
+ self.padding = self._same_padding
428
+
429
+ self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size)
430
+ self.conv_2 = nn.Conv1d(filter_channels, out_channels, kernel_size)
431
+ self.drop = nn.Dropout(p_dropout)
432
+
433
+ def forward(self, x, x_mask):
434
+ x = self.conv_1(self.padding(x * x_mask))
435
+ if self.activation == "gelu":
436
+ x = x * torch.sigmoid(1.702 * x)
437
+ else:
438
+ x = torch.relu(x)
439
+ x = self.drop(x)
440
+ x = self.conv_2(self.padding(x * x_mask))
441
+ return x * x_mask
442
+
443
+ def _causal_padding(self, x):
444
+ if self.kernel_size == 1:
445
+ return x
446
+ pad_l = self.kernel_size - 1
447
+ pad_r = 0
448
+ padding = [[0, 0], [0, 0], [pad_l, pad_r]]
449
+ x = F.pad(x, commons.convert_pad_shape(padding))
450
+ return x
451
+
452
+ def _same_padding(self, x):
453
+ if self.kernel_size == 1:
454
+ return x
455
+ pad_l = (self.kernel_size - 1) // 2
456
+ pad_r = self.kernel_size // 2
457
+ padding = [[0, 0], [0, 0], [pad_l, pad_r]]
458
+ x = F.pad(x, commons.convert_pad_shape(padding))
459
+ return x
MeloTTS/melo/commons.py ADDED
@@ -0,0 +1,160 @@
1
+ import math
2
+ import torch
3
+ from torch.nn import functional as F
4
+
5
+
6
+ def init_weights(m, mean=0.0, std=0.01):
7
+ classname = m.__class__.__name__
8
+ if classname.find("Conv") != -1:
9
+ m.weight.data.normal_(mean, std)
10
+
11
+
12
+ def get_padding(kernel_size, dilation=1):
13
+ return int((kernel_size * dilation - dilation) / 2)
14
+
15
+
16
+ def convert_pad_shape(pad_shape):
17
+ layer = pad_shape[::-1]
18
+ pad_shape = [item for sublist in layer for item in sublist]
19
+ return pad_shape
20
+
21
+
22
+ def intersperse(lst, item):
23
+ result = [item] * (len(lst) * 2 + 1)
24
+ result[1::2] = lst
25
+ return result
26
+
27
+
28
+ def kl_divergence(m_p, logs_p, m_q, logs_q):
29
+ """KL(P||Q)"""
30
+ kl = (logs_q - logs_p) - 0.5
31
+ kl += (
32
+ 0.5 * (torch.exp(2.0 * logs_p) + ((m_p - m_q) ** 2)) * torch.exp(-2.0 * logs_q)
33
+ )
34
+ return kl
35
+
36
+
37
+ def rand_gumbel(shape):
38
+ """Sample from the Gumbel distribution, protect from overflows."""
39
+ uniform_samples = torch.rand(shape) * 0.99998 + 0.00001
40
+ return -torch.log(-torch.log(uniform_samples))
41
+
42
+
43
+ def rand_gumbel_like(x):
44
+ g = rand_gumbel(x.size()).to(dtype=x.dtype, device=x.device)
45
+ return g
46
+
47
+
48
+ def slice_segments(x, ids_str, segment_size=4):
49
+ ret = torch.zeros_like(x[:, :, :segment_size])
50
+ for i in range(x.size(0)):
51
+ idx_str = ids_str[i]
52
+ idx_end = idx_str + segment_size
53
+ ret[i] = x[i, :, idx_str:idx_end]
54
+ return ret
55
+
56
+
57
+ def rand_slice_segments(x, x_lengths=None, segment_size=4):
58
+ b, d, t = x.size()
59
+ if x_lengths is None:
60
+ x_lengths = t
61
+ ids_str_max = x_lengths - segment_size + 1
62
+ ids_str = (torch.rand([b]).to(device=x.device) * ids_str_max).to(dtype=torch.long)
63
+ ret = slice_segments(x, ids_str, segment_size)
64
+ return ret, ids_str
65
+
66
+
67
+ def get_timing_signal_1d(length, channels, min_timescale=1.0, max_timescale=1.0e4):
68
+ position = torch.arange(length, dtype=torch.float)
69
+ num_timescales = channels // 2
70
+ log_timescale_increment = math.log(float(max_timescale) / float(min_timescale)) / (
71
+ num_timescales - 1
72
+ )
73
+ inv_timescales = min_timescale * torch.exp(
74
+ torch.arange(num_timescales, dtype=torch.float) * -log_timescale_increment
75
+ )
76
+ scaled_time = position.unsqueeze(0) * inv_timescales.unsqueeze(1)
77
+ signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], 0)
78
+ signal = F.pad(signal, [0, 0, 0, channels % 2])
79
+ signal = signal.view(1, channels, length)
80
+ return signal
81
+
82
+
83
+ def add_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4):
84
+ b, channels, length = x.size()
85
+ signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
86
+ return x + signal.to(dtype=x.dtype, device=x.device)
87
+
88
+
89
+ def cat_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4, axis=1):
90
+ b, channels, length = x.size()
91
+ signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
92
+ return torch.cat([x, signal.to(dtype=x.dtype, device=x.device)], axis)
93
+
94
+
95
+ def subsequent_mask(length):
96
+ mask = torch.tril(torch.ones(length, length)).unsqueeze(0).unsqueeze(0)
97
+ return mask
98
+
99
+
100
+ @torch.jit.script
101
+ def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
102
+ n_channels_int = n_channels[0]
103
+ in_act = input_a + input_b
104
+ t_act = torch.tanh(in_act[:, :n_channels_int, :])
105
+ s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
106
+ acts = t_act * s_act
107
+ return acts
108
+
109
+
110
+ def convert_pad_shape(pad_shape):
111
+ layer = pad_shape[::-1]
112
+ pad_shape = [item for sublist in layer for item in sublist]
113
+ return pad_shape
114
+
115
+
116
+ def shift_1d(x):
117
+ x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [1, 0]]))[:, :, :-1]
118
+ return x
119
+
120
+
121
+ def sequence_mask(length, max_length=None):
122
+ if max_length is None:
123
+ max_length = length.max()
124
+ x = torch.arange(max_length, dtype=length.dtype, device=length.device)
125
+ return x.unsqueeze(0) < length.unsqueeze(1)
126
+
127
+
128
+ def generate_path(duration, mask):
129
+ """
130
+ duration: [b, 1, t_x]
131
+ mask: [b, 1, t_y, t_x]
132
+ """
133
+
134
+ b, _, t_y, t_x = mask.shape
135
+ cum_duration = torch.cumsum(duration, -1)
136
+
137
+ cum_duration_flat = cum_duration.view(b * t_x)
138
+ path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype)
139
+ path = path.view(b, t_x, t_y)
140
+ path = path - F.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1]
141
+ path = path.unsqueeze(1).transpose(2, 3) * mask
142
+ return path
143
+
144
+
145
+ def clip_grad_value_(parameters, clip_value, norm_type=2):
146
+ if isinstance(parameters, torch.Tensor):
147
+ parameters = [parameters]
148
+ parameters = list(filter(lambda p: p.grad is not None, parameters))
149
+ norm_type = float(norm_type)
150
+ if clip_value is not None:
151
+ clip_value = float(clip_value)
152
+
153
+ total_norm = 0
154
+ for p in parameters:
155
+ param_norm = p.grad.data.norm(norm_type)
156
+ total_norm += param_norm.item() ** norm_type
157
+ if clip_value is not None:
158
+ p.grad.data.clamp_(min=-clip_value, max=clip_value)
159
+ total_norm = total_norm ** (1.0 / norm_type)
160
+ return total_norm
MeloTTS/melo/configs/config.json ADDED
@@ -0,0 +1,94 @@
1
+ {
2
+ "train": {
3
+ "log_interval": 200,
4
+ "eval_interval": 1000,
5
+ "seed": 52,
6
+ "epochs": 10000,
7
+ "learning_rate": 0.0003,
8
+ "betas": [
9
+ 0.8,
10
+ 0.99
11
+ ],
12
+ "eps": 1e-09,
13
+ "batch_size": 6,
14
+ "fp16_run": false,
15
+ "lr_decay": 0.999875,
16
+ "segment_size": 16384,
17
+ "init_lr_ratio": 1,
18
+ "warmup_epochs": 0,
19
+ "c_mel": 45,
20
+ "c_kl": 1.0,
21
+ "skip_optimizer": true
22
+ },
23
+ "data": {
24
+ "training_files": "",
25
+ "validation_files": "",
26
+ "max_wav_value": 32768.0,
27
+ "sampling_rate": 44100,
28
+ "filter_length": 2048,
29
+ "hop_length": 512,
30
+ "win_length": 2048,
31
+ "n_mel_channels": 128,
32
+ "mel_fmin": 0.0,
33
+ "mel_fmax": null,
34
+ "add_blank": true,
35
+ "n_speakers": 256,
36
+ "cleaned_text": true,
37
+ "spk2id": {}
38
+ },
39
+ "model": {
40
+ "use_spk_conditioned_encoder": true,
41
+ "use_noise_scaled_mas": true,
42
+ "use_mel_posterior_encoder": false,
43
+ "use_duration_discriminator": true,
44
+ "inter_channels": 192,
45
+ "hidden_channels": 192,
46
+ "filter_channels": 768,
47
+ "n_heads": 2,
48
+ "n_layers": 6,
49
+ "n_layers_trans_flow": 3,
50
+ "kernel_size": 3,
51
+ "p_dropout": 0.1,
52
+ "resblock": "1",
53
+ "resblock_kernel_sizes": [
54
+ 3,
55
+ 7,
56
+ 11
57
+ ],
58
+ "resblock_dilation_sizes": [
59
+ [
60
+ 1,
61
+ 3,
62
+ 5
63
+ ],
64
+ [
65
+ 1,
66
+ 3,
67
+ 5
68
+ ],
69
+ [
70
+ 1,
71
+ 3,
72
+ 5
73
+ ]
74
+ ],
75
+ "upsample_rates": [
76
+ 8,
77
+ 8,
78
+ 2,
79
+ 2,
80
+ 2
81
+ ],
82
+ "upsample_initial_channel": 512,
83
+ "upsample_kernel_sizes": [
84
+ 16,
85
+ 16,
86
+ 8,
87
+ 2,
88
+ 2
89
+ ],
90
+ "n_layers_q": 3,
91
+ "use_spectral_norm": false,
92
+ "gin_channels": 256
93
+ }
94
+ }
MeloTTS/melo/data/example/metadata.list ADDED
@@ -0,0 +1,20 @@
1
+ data/example/wavs/000.wav|EN-default|EN|Well, there are always new trends and styles emerging in the fashion world, but I think some of the biggest trends at the moment include sustainability and ethical fashion, streetwear and athleisure, and oversized and deconstructed silhouettes.
2
+ data/example/wavs/001.wav|EN-default|EN|Many designers and brands are focusing on creating more environmentally-friendly and socially responsible clothing, while others are incorporating elements of sportswear and casual wear into their collections.
3
+ data/example/wavs/002.wav|EN-default|EN|And there's a growing interest in looser, more relaxed shapes and unconventional materials and finishes.
4
+ data/example/wavs/003.wav|EN-default|EN|That's really insightful.
5
+ data/example/wavs/004.wav|EN-default|EN|What do you think are some of the benefits of following fashion trends?
6
+ data/example/wavs/005.wav|EN-default|EN|Well, I think one of the main benefits of following fashion trends is that it can be a way to express your creativity, personality, and individuality.
7
+ data/example/wavs/006.wav|EN-default|EN|Fashion can be a powerful tool for self-expression and can help you feel more confident and comfortable in your own skin.
8
+ data/example/wavs/007.wav|EN-default|EN|Additionally, staying up-to-date with fashion trends can help you develop your own sense of style and learn how to put together outfits that make you look and feel great.
9
+ data/example/wavs/008.wav|EN-default|EN|That's a great point.
10
+ data/example/wavs/009.wav|EN-default|EN|Do you think it's important to stay on top of the latest fashion trends, or is it more important to focus on timeless style?
11
+ data/example/wavs/010.wav|EN-default|EN|I think it's really up to each individual to decide what approach to fashion works best for them.
12
+ data/example/wavs/011.wav|EN-default|EN|Some people prefer to stick with classic, timeless styles that never go out of fashion, while others enjoy experimenting with new and innovative trends.
13
+ data/example/wavs/012.wav|EN-default|EN|Ultimately, fashion is about personal expression and there's no right or wrong way to approach it.
14
+ data/example/wavs/013.wav|EN-default|EN|The most important thing is to wear what makes you feel good and confident.
15
+ data/example/wavs/014.wav|EN-default|EN|I completely agree.
16
+ data/example/wavs/015.wav|EN-default|EN|Some popular ones that come to mind are oversized blazers, statement sleeves, printed maxi dresses, and chunky sneakers.
17
+ data/example/wavs/016.wav|EN-default|EN|It's been really interesting chatting with you about fashion.
18
+ data/example/wavs/017.wav|EN-default|EN|That's a good point.
19
+ data/example/wavs/018.wav|EN-default|EN|What do you think are some current fashion trends that are popular right now?
20
+ data/example/wavs/019.wav|EN-default|EN|There are so many trends happening right now, it's hard to keep track of them all!
MeloTTS/melo/data_utils.py ADDED
@@ -0,0 +1,413 @@
1
+ import os
2
+ import random
3
+ import torch
4
+ import torch.utils.data
5
+ from tqdm import tqdm
6
+ from loguru import logger
7
+ import commons
8
+ from mel_processing import spectrogram_torch, mel_spectrogram_torch
9
+ from utils import load_filepaths_and_text
10
+ from utils import load_wav_to_torch_librosa as load_wav_to_torch
11
+ from text import cleaned_text_to_sequence, get_bert
12
+ import numpy as np
13
+
14
+ """Multi speaker version"""
15
+
16
+
17
+ class TextAudioSpeakerLoader(torch.utils.data.Dataset):
18
+ """
19
+ 1) loads audio, speaker_id, text pairs
20
+ 2) normalizes text and converts them to sequences of integers
21
+ 3) computes spectrograms from audio files.
22
+ """
23
+
24
+ def __init__(self, audiopaths_sid_text, hparams):
25
+ self.audiopaths_sid_text = load_filepaths_and_text(audiopaths_sid_text)
26
+ self.max_wav_value = hparams.max_wav_value
27
+ self.sampling_rate = hparams.sampling_rate
28
+ self.filter_length = hparams.filter_length
29
+ self.hop_length = hparams.hop_length
30
+ self.win_length = hparams.win_length
31
+ self.sampling_rate = hparams.sampling_rate
32
+ self.spk_map = hparams.spk2id
33
+ self.hparams = hparams
34
+ self.disable_bert = getattr(hparams, "disable_bert", False)
35
+
36
+ self.use_mel_spec_posterior = getattr(
37
+ hparams, "use_mel_posterior_encoder", False
38
+ )
39
+ if self.use_mel_spec_posterior:
40
+ self.n_mel_channels = getattr(hparams, "n_mel_channels", 80)
41
+
42
+ self.cleaned_text = getattr(hparams, "cleaned_text", False)
43
+
44
+ self.add_blank = hparams.add_blank
45
+ self.min_text_len = getattr(hparams, "min_text_len", 1)
46
+ self.max_text_len = getattr(hparams, "max_text_len", 300)
47
+
48
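+ # shuffle with a fixed seed so the filtered file list and its length statistics are reproducible across runs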
+ random.seed(1234)
49
+ random.shuffle(self.audiopaths_sid_text)
50
+ self._filter()
51
+
52
+
53
+ def _filter(self):
54
+ """
55
+ Filter text & store spec lengths
56
+ """
57
+ # Store spectrogram lengths for Bucketing
58
+ # wav_length ~= file_size / (wav_channels * Bytes per dim) = file_size / (1 * 2)
59
+ # spec_length = wav_length // hop_length
60
+
61
+ audiopaths_sid_text_new = []
62
+ lengths = []
63
+ skipped = 0
64
+ logger.info("Init dataset...")
65
+ for item in tqdm(
66
+ self.audiopaths_sid_text
67
+ ):
68
+ try:
69
+ _id, spk, language, text, phones, tone, word2ph = item
70
+ except:
71
+ print(item)
72
+ raise
73
+ audiopath = f"{_id}"
74
+ if self.min_text_len <= len(phones) and len(phones) <= self.max_text_len:
75
+ phones = phones.split(" ")
76
+ tone = [int(i) for i in tone.split(" ")]
77
+ word2ph = [int(i) for i in word2ph.split(" ")]
78
+ audiopaths_sid_text_new.append(
79
+ [audiopath, spk, language, text, phones, tone, word2ph]
80
+ )
81
+ lengths.append(os.path.getsize(audiopath) // (2 * self.hop_length))
82
+ else:
83
+ skipped += 1
84
+ logger.info(f'min: {min(lengths)}; max: {max(lengths)}' )
85
+ logger.info(
86
+ "skipped: "
87
+ + str(skipped)
88
+ + ", total: "
89
+ + str(len(self.audiopaths_sid_text))
90
+ )
91
+ self.audiopaths_sid_text = audiopaths_sid_text_new
92
+ self.lengths = lengths
93
+
94
+ def get_audio_text_speaker_pair(self, audiopath_sid_text):
95
+ # separate filename, speaker_id and text
96
+ audiopath, sid, language, text, phones, tone, word2ph = audiopath_sid_text
97
+
98
+ bert, ja_bert, phones, tone, language = self.get_text(
99
+ text, word2ph, phones, tone, language, audiopath
100
+ )
101
+
102
+ spec, wav = self.get_audio(audiopath)
103
+ sid = int(getattr(self.spk_map, sid, '0'))
104
+ sid = torch.LongTensor([sid])
105
+ return (phones, spec, wav, sid, tone, language, bert, ja_bert)
106
+
107
+ def get_audio(self, filename):
108
+ audio_norm, sampling_rate = load_wav_to_torch(filename, self.sampling_rate)
109
+ if sampling_rate != self.sampling_rate:
110
+ raise ValueError(
111
+ "{} {} SR doesn't match target {} SR".format(
112
+ filename, sampling_rate, self.sampling_rate
113
+ )
114
+ )
115
+ # NOTE: normalize has been achieved by torchaudio
116
+ # audio_norm = audio / self.max_wav_value
117
+ audio_norm = audio_norm.unsqueeze(0)
118
+ spec_filename = filename.replace(".wav", ".spec.pt")
119
+ if self.use_mel_spec_posterior:
120
+ spec_filename = spec_filename.replace(".spec.pt", ".mel.pt")
121
+ try:
122
+ spec = torch.load(spec_filename)
123
+ assert False  # NOTE: always falls through to the except branch, so any cached spectrogram is ignored and recomputed
124
+ except:
125
+ if self.use_mel_spec_posterior:
126
+ spec = mel_spectrogram_torch(
127
+ audio_norm,
128
+ self.filter_length,
129
+ self.n_mel_channels,
130
+ self.sampling_rate,
131
+ self.hop_length,
132
+ self.win_length,
133
+ self.hparams.mel_fmin,
134
+ self.hparams.mel_fmax,
135
+ center=False,
136
+ )
137
+ else:
138
+ spec = spectrogram_torch(
139
+ audio_norm,
140
+ self.filter_length,
141
+ self.sampling_rate,
142
+ self.hop_length,
143
+ self.win_length,
144
+ center=False,
145
+ )
146
+ spec = torch.squeeze(spec, 0)
147
+ torch.save(spec, spec_filename)
148
+ return spec, audio_norm
149
+
150
+ def get_text(self, text, word2ph, phone, tone, language_str, wav_path):
151
+ phone, tone, language = cleaned_text_to_sequence(phone, tone, language_str)
152
+ if self.add_blank:
153
+ phone = commons.intersperse(phone, 0)
154
+ tone = commons.intersperse(tone, 0)
155
+ language = commons.intersperse(language, 0)
156
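+ # blanks were interspersed above (one around every token), so each word now covers twice as many phone slots; the first word also absorbs the extra leading blank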
+ for i in range(len(word2ph)):
157
+ word2ph[i] = word2ph[i] * 2
158
+ word2ph[0] += 1
159
+ bert_path = wav_path.replace(".wav", ".bert.pt")
160
+ try:
161
+ bert = torch.load(bert_path)
162
+ assert bert.shape[-1] == len(phone)
163
+ except Exception as e:
164
+ print(e, wav_path, bert_path, len(phone))  # do not reference bert here: it is undefined when torch.load fails
165
+ bert = get_bert(text, word2ph, language_str)
166
+ torch.save(bert, bert_path)
167
+ assert bert.shape[-1] == len(phone), phone
168
+
169
+ if self.disable_bert:
170
+ bert = torch.zeros(1024, len(phone))
171
+ ja_bert = torch.zeros(768, len(phone))
172
+ else:
173
+ if language_str in ["ZH"]:
174
+ bert = bert
175
+ ja_bert = torch.zeros(768, len(phone))
176
+ elif language_str in ["JP", "EN", "ZH_MIX_EN", "KR", 'SP', 'ES', 'FR', 'DE', 'RU']:
177
+ ja_bert = bert
178
+ bert = torch.zeros(1024, len(phone))
179
+ else:
180
+ raise ValueError(f"unsupported language_str: {language_str}")  # the zero-tensor fallbacks below are unreachable
181
+ bert = torch.zeros(1024, len(phone))
182
+ ja_bert = torch.zeros(768, len(phone))
183
+ assert bert.shape[-1] == len(phone)
184
+ phone = torch.LongTensor(phone)
185
+ tone = torch.LongTensor(tone)
186
+ language = torch.LongTensor(language)
187
+ return bert, ja_bert, phone, tone, language
188
+
189
+ def get_sid(self, sid):
190
+ sid = torch.LongTensor([int(sid)])
191
+ return sid
192
+
193
+ def __getitem__(self, index):
194
+ return self.get_audio_text_speaker_pair(self.audiopaths_sid_text[index])
195
+
196
+ def __len__(self):
197
+ return len(self.audiopaths_sid_text)
198
+
199
+
200
+ class TextAudioSpeakerCollate:
201
+ """Zero-pads model inputs and targets"""
202
+
203
+ def __init__(self, return_ids=False):
204
+ self.return_ids = return_ids
205
+
206
+ def __call__(self, batch):
207
+ """Collates a training batch from normalized text, audio and speaker identities
208
+ PARAMS
209
+ ------
210
+ batch: [text_normalized, spec_normalized, wav_normalized, sid]
211
+ """
212
+ # Right zero-pad all one-hot text sequences to max input length
213
+ _, ids_sorted_decreasing = torch.sort(
214
+ torch.LongTensor([x[1].size(1) for x in batch]), dim=0, descending=True
215
+ )
216
+
217
+ max_text_len = max([len(x[0]) for x in batch])
218
+ max_spec_len = max([x[1].size(1) for x in batch])
219
+ max_wav_len = max([x[2].size(1) for x in batch])
220
+
221
+ text_lengths = torch.LongTensor(len(batch))
222
+ spec_lengths = torch.LongTensor(len(batch))
223
+ wav_lengths = torch.LongTensor(len(batch))
224
+ sid = torch.LongTensor(len(batch))
225
+
226
+ text_padded = torch.LongTensor(len(batch), max_text_len)
227
+ tone_padded = torch.LongTensor(len(batch), max_text_len)
228
+ language_padded = torch.LongTensor(len(batch), max_text_len)
229
+ bert_padded = torch.FloatTensor(len(batch), 1024, max_text_len)
230
+ ja_bert_padded = torch.FloatTensor(len(batch), 768, max_text_len)
231
+
232
+ spec_padded = torch.FloatTensor(len(batch), batch[0][1].size(0), max_spec_len)
233
+ wav_padded = torch.FloatTensor(len(batch), 1, max_wav_len)
234
+ text_padded.zero_()
235
+ tone_padded.zero_()
236
+ language_padded.zero_()
237
+ spec_padded.zero_()
238
+ wav_padded.zero_()
239
+ bert_padded.zero_()
240
+ ja_bert_padded.zero_()
241
+ for i in range(len(ids_sorted_decreasing)):
242
+ row = batch[ids_sorted_decreasing[i]]
243
+
244
+ text = row[0]
245
+ text_padded[i, : text.size(0)] = text
246
+ text_lengths[i] = text.size(0)
247
+
248
+ spec = row[1]
249
+ spec_padded[i, :, : spec.size(1)] = spec
250
+ spec_lengths[i] = spec.size(1)
251
+
252
+ wav = row[2]
253
+ wav_padded[i, :, : wav.size(1)] = wav
254
+ wav_lengths[i] = wav.size(1)
255
+
256
+ sid[i] = row[3]
257
+
258
+ tone = row[4]
259
+ tone_padded[i, : tone.size(0)] = tone
260
+
261
+ language = row[5]
262
+ language_padded[i, : language.size(0)] = language
263
+
264
+ bert = row[6]
265
+ bert_padded[i, :, : bert.size(1)] = bert
266
+
267
+ ja_bert = row[7]
268
+ ja_bert_padded[i, :, : ja_bert.size(1)] = ja_bert
269
+
270
+ return (
271
+ text_padded,
272
+ text_lengths,
273
+ spec_padded,
274
+ spec_lengths,
275
+ wav_padded,
276
+ wav_lengths,
277
+ sid,
278
+ tone_padded,
279
+ language_padded,
280
+ bert_padded,
281
+ ja_bert_padded,
282
+ )
283
+
284
+
285
+ class DistributedBucketSampler(torch.utils.data.distributed.DistributedSampler):
286
+ """
287
+ Maintain similar input lengths in a batch.
288
+ Length groups are specified by boundaries.
289
+ Ex) boundaries = [b1, b2, b3] -> every sample in a batch lies in either {x | b1 < length(x) <= b2} or {x | b2 < length(x) <= b3}.
290
+
291
+ It removes samples which are not included in the boundaries.
292
+ Ex) boundaries = [b1, b2, b3] -> any x s.t. length(x) <= b1 or length(x) > b3 is discarded.
293
+ """
294
+
295
+ def __init__(
296
+ self,
297
+ dataset,
298
+ batch_size,
299
+ boundaries,
300
+ num_replicas=None,
301
+ rank=None,
302
+ shuffle=True,
303
+ ):
304
+ super().__init__(dataset, num_replicas=num_replicas, rank=rank, shuffle=shuffle)
305
+ self.lengths = dataset.lengths
306
+ self.batch_size = batch_size
307
+ self.boundaries = boundaries
308
+
309
+ self.buckets, self.num_samples_per_bucket = self._create_buckets()
310
+ self.total_size = sum(self.num_samples_per_bucket)
311
+ self.num_samples = self.total_size // self.num_replicas
312
+ print('buckets:', self.num_samples_per_bucket)
313
+
314
+ def _create_buckets(self):
315
+ buckets = [[] for _ in range(len(self.boundaries) - 1)]
316
+ for i in range(len(self.lengths)):
317
+ length = self.lengths[i]
318
+ idx_bucket = self._bisect(length)
319
+ if idx_bucket != -1:
320
+ buckets[idx_bucket].append(i)
321
+
322
+ try:
323
+ for i in range(len(buckets) - 1, 0, -1):
324
+ if len(buckets[i]) == 0:
325
+ buckets.pop(i)
326
+ self.boundaries.pop(i + 1)
327
+ assert all(len(bucket) > 0 for bucket in buckets)
328
+ # When one bucket is not traversed
329
+ except Exception as e:
330
+ print("Bucket warning ", e)
331
+ for i in range(len(buckets) - 1, -1, -1):
332
+ if len(buckets[i]) == 0:
333
+ buckets.pop(i)
334
+ self.boundaries.pop(i + 1)
335
+
336
+ num_samples_per_bucket = []
337
+ for i in range(len(buckets)):
338
+ len_bucket = len(buckets[i])
339
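+ # pad every bucket up to a multiple of (num_replicas * batch_size) so each replica can draw complete batches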
+ total_batch_size = self.num_replicas * self.batch_size
340
+ rem = (
341
+ total_batch_size - (len_bucket % total_batch_size)
342
+ ) % total_batch_size
343
+ num_samples_per_bucket.append(len_bucket + rem)
344
+ return buckets, num_samples_per_bucket
345
+
346
+ def __iter__(self):
347
+ # deterministically shuffle based on epoch
348
+ g = torch.Generator()
349
+ g.manual_seed(self.epoch)
350
+
351
+ indices = []
352
+ if self.shuffle:
353
+ for bucket in self.buckets:
354
+ indices.append(torch.randperm(len(bucket), generator=g).tolist())
355
+ else:
356
+ for bucket in self.buckets:
357
+ indices.append(list(range(len(bucket))))
358
+
359
+ batches = []
360
+ for i in range(len(self.buckets)):
361
+ bucket = self.buckets[i]
362
+ len_bucket = len(bucket)
363
+ if len_bucket == 0:
364
+ continue
365
+ ids_bucket = indices[i]
366
+ num_samples_bucket = self.num_samples_per_bucket[i]
367
+
368
+ # add extra samples to make it evenly divisible
369
+ rem = num_samples_bucket - len_bucket
370
+ ids_bucket = (
371
+ ids_bucket
372
+ + ids_bucket * (rem // len_bucket)
373
+ + ids_bucket[: (rem % len_bucket)]
374
+ )
375
+
376
+ # subsample
377
+ ids_bucket = ids_bucket[self.rank :: self.num_replicas]
378
+
379
+ # batching
380
+ for j in range(len(ids_bucket) // self.batch_size):
381
+ batch = [
382
+ bucket[idx]
383
+ for idx in ids_bucket[
384
+ j * self.batch_size : (j + 1) * self.batch_size
385
+ ]
386
+ ]
387
+ batches.append(batch)
388
+
389
+ if self.shuffle:
390
+ batch_ids = torch.randperm(len(batches), generator=g).tolist()
391
+ batches = [batches[i] for i in batch_ids]
392
+ self.batches = batches
393
+
394
+ assert len(self.batches) * self.batch_size == self.num_samples
395
+ return iter(self.batches)
396
+
397
+ def _bisect(self, x, lo=0, hi=None):
398
+ if hi is None:
399
+ hi = len(self.boundaries) - 1
400
+
401
+ if hi > lo:
402
+ mid = (hi + lo) // 2
403
+ if self.boundaries[mid] < x and x <= self.boundaries[mid + 1]:
404
+ return mid
405
+ elif x <= self.boundaries[mid]:
406
+ return self._bisect(x, lo, mid)
407
+ else:
408
+ return self._bisect(x, mid + 1, hi)
409
+ else:
410
+ return -1
411
+
412
+ def __len__(self):
413
+ return self.num_samples // self.batch_size
MeloTTS/melo/download_utils.py ADDED
@@ -0,0 +1,67 @@
1
+ import torch
2
+ import os
3
+ from . import utils
4
+ from cached_path import cached_path
5
+ from huggingface_hub import hf_hub_download
6
+
7
+ DOWNLOAD_CKPT_URLS = {
8
+ 'EN': 'https://myshell-public-repo-host.s3.amazonaws.com/openvoice/basespeakers/EN/checkpoint.pth',
9
+ 'EN_V2': 'https://myshell-public-repo-host.s3.amazonaws.com/openvoice/basespeakers/EN_V2/checkpoint.pth',
10
+ 'FR': 'https://myshell-public-repo-host.s3.amazonaws.com/openvoice/basespeakers/FR/checkpoint.pth',
11
+ 'JP': 'https://myshell-public-repo-host.s3.amazonaws.com/openvoice/basespeakers/JP/checkpoint.pth',
12
+ 'ES': 'https://myshell-public-repo-host.s3.amazonaws.com/openvoice/basespeakers/ES/checkpoint.pth',
13
+ 'ZH': 'https://myshell-public-repo-host.s3.amazonaws.com/openvoice/basespeakers/ZH/checkpoint.pth',
14
+ 'KR': 'https://myshell-public-repo-host.s3.amazonaws.com/openvoice/basespeakers/KR/checkpoint.pth',
15
+ }
16
+
17
+ DOWNLOAD_CONFIG_URLS = {
18
+ 'EN': 'https://myshell-public-repo-host.s3.amazonaws.com/openvoice/basespeakers/EN/config.json',
19
+ 'EN_V2': 'https://myshell-public-repo-host.s3.amazonaws.com/openvoice/basespeakers/EN_V2/config.json',
20
+ 'FR': 'https://myshell-public-repo-host.s3.amazonaws.com/openvoice/basespeakers/FR/config.json',
21
+ 'JP': 'https://myshell-public-repo-host.s3.amazonaws.com/openvoice/basespeakers/JP/config.json',
22
+ 'ES': 'https://myshell-public-repo-host.s3.amazonaws.com/openvoice/basespeakers/ES/config.json',
23
+ 'ZH': 'https://myshell-public-repo-host.s3.amazonaws.com/openvoice/basespeakers/ZH/config.json',
24
+ 'KR': 'https://myshell-public-repo-host.s3.amazonaws.com/openvoice/basespeakers/KR/config.json',
25
+ }
26
+
27
+ PRETRAINED_MODELS = {
28
+ 'G.pth': 'https://myshell-public-repo-host.s3.amazonaws.com/openvoice/basespeakers/pretrained/G.pth',
29
+ 'D.pth': 'https://myshell-public-repo-host.s3.amazonaws.com/openvoice/basespeakers/pretrained/D.pth',
30
+ 'DUR.pth': 'https://myshell-public-repo-host.s3.amazonaws.com/openvoice/basespeakers/pretrained/DUR.pth',
31
+ }
32
+
33
+ LANG_TO_HF_REPO_ID = {
34
+ 'EN': 'myshell-ai/MeloTTS-English',
35
+ 'EN_V2': 'myshell-ai/MeloTTS-English-v2',
36
+ 'EN_NEWEST': 'myshell-ai/MeloTTS-English-v3',
37
+ 'FR': 'myshell-ai/MeloTTS-French',
38
+ 'JP': 'myshell-ai/MeloTTS-Japanese',
39
+ 'ES': 'myshell-ai/MeloTTS-Spanish',
40
+ 'ZH': 'myshell-ai/MeloTTS-Chinese',
41
+ 'KR': 'myshell-ai/MeloTTS-Korean',
42
+ }
43
+
44
+ def load_or_download_config(locale, use_hf=True, config_path=None):
45
+ if config_path is None:
46
+ language = locale.split('-')[0].upper()
47
+ if use_hf:
48
+ assert language in LANG_TO_HF_REPO_ID
49
+ config_path = hf_hub_download(repo_id=LANG_TO_HF_REPO_ID[language], filename="config.json")
50
+ else:
51
+ assert language in DOWNLOAD_CONFIG_URLS
52
+ config_path = cached_path(DOWNLOAD_CONFIG_URLS[language])
53
+ return utils.get_hparams_from_file(config_path)
54
+
55
+ def load_or_download_model(locale, device, use_hf=True, ckpt_path=None):
56
+ if ckpt_path is None:
57
+ language = locale.split('-')[0].upper()
58
+ if use_hf:
59
+ assert language in LANG_TO_HF_REPO_ID
60
+ ckpt_path = hf_hub_download(repo_id=LANG_TO_HF_REPO_ID[language], filename="checkpoint.pth")
61
+ else:
62
+ assert language in DOWNLOAD_CKPT_URLS
63
+ ckpt_path = cached_path(DOWNLOAD_CKPT_URLS[language])
64
+ return torch.load(ckpt_path, map_location=device)
65
+
66
+ def load_pretrain_model():
67
+ return [cached_path(url) for url in PRETRAINED_MODELS.values()]
MeloTTS/melo/infer.py ADDED
@@ -0,0 +1,25 @@
1
+ import os
2
+ import click
3
+ from melo.api import TTS
4
+
5
+
6
+
7
+ @click.command()
8
+ @click.option('--ckpt_path', '-m', type=str, default=None, help="Path to the checkpoint file")
9
+ @click.option('--text', '-t', type=str, default=None, help="Text to speak")
10
+ @click.option('--language', '-l', type=str, default="EN", help="Language of the model")
11
+ @click.option('--output_dir', '-o', type=str, default="outputs", help="Path to the output")
12
+ def main(ckpt_path, text, language, output_dir):
13
+ if ckpt_path is None:
14
+ raise ValueError("The model_path must be specified")
15
+
16
+ config_path = os.path.join(os.path.dirname(ckpt_path), 'config.json')
17
+ model = TTS(language=language, config_path=config_path, ckpt_path=ckpt_path)
18
+
19
+ for spk_name, spk_id in model.hps.data.spk2id.items():
20
+ save_path = f'{output_dir}/{spk_name}/output.wav'
21
+ os.makedirs(os.path.dirname(save_path), exist_ok=True)
22
+ model.tts_to_file(text, spk_id, save_path)
23
+
24
+ if __name__ == "__main__":
25
+ main()
MeloTTS/melo/init_downloads.py ADDED
@@ -0,0 +1,14 @@
1
+
2
+
3
+ if __name__ == '__main__':
4
+
5
+ from melo.api import TTS
6
+ device = 'auto'
7
+ models = {
8
+ 'EN': TTS(language='EN', device=device),
9
+ 'ES': TTS(language='ES', device=device),
10
+ 'FR': TTS(language='FR', device=device),
11
+ 'ZH': TTS(language='ZH', device=device),
12
+ 'JP': TTS(language='JP', device=device),
13
+ 'KR': TTS(language='KR', device=device),
14
+ }
MeloTTS/melo/losses.py ADDED
@@ -0,0 +1,58 @@
1
+ import torch
2
+
3
+
4
+ def feature_loss(fmap_r, fmap_g):
5
+ loss = 0
6
+ for dr, dg in zip(fmap_r, fmap_g):
7
+ for rl, gl in zip(dr, dg):
8
+ rl = rl.float().detach()
9
+ gl = gl.float()
10
+ loss += torch.mean(torch.abs(rl - gl))
11
+
12
+ return loss * 2
13
+
14
+
15
+ def discriminator_loss(disc_real_outputs, disc_generated_outputs):
16
+ loss = 0
17
+ r_losses = []
18
+ g_losses = []
19
+ for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
20
+ dr = dr.float()
21
+ dg = dg.float()
22
+ r_loss = torch.mean((1 - dr) ** 2)
23
+ g_loss = torch.mean(dg**2)
24
+ loss += r_loss + g_loss
25
+ r_losses.append(r_loss.item())
26
+ g_losses.append(g_loss.item())
27
+
28
+ return loss, r_losses, g_losses
29
+
30
+
31
+ def generator_loss(disc_outputs):
32
+ loss = 0
33
+ gen_losses = []
34
+ for dg in disc_outputs:
35
+ dg = dg.float()
36
+ l = torch.mean((1 - dg) ** 2)
37
+ gen_losses.append(l)
38
+ loss += l
39
+
40
+ return loss, gen_losses
41
+
42
+
43
+ def kl_loss(z_p, logs_q, m_p, logs_p, z_mask):
44
+ """
45
+ z_p, logs_q: [b, h, t_t]
46
+ m_p, logs_p: [b, h, t_t]
47
+ """
48
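+ # single-sample estimate of KL(q || p) between diagonal Gaussians, using z_p drawn from q, masked and averaged over valid frames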
+ z_p = z_p.float()
49
+ logs_q = logs_q.float()
50
+ m_p = m_p.float()
51
+ logs_p = logs_p.float()
52
+ z_mask = z_mask.float()
53
+
54
+ kl = logs_p - logs_q - 0.5
55
+ kl += 0.5 * ((z_p - m_p) ** 2) * torch.exp(-2.0 * logs_p)
56
+ kl = torch.sum(kl * z_mask)
57
+ l = kl / torch.sum(z_mask)
58
+ return l
MeloTTS/melo/main.py ADDED
@@ -0,0 +1,36 @@
1
+ import click
2
+ import warnings
3
+ import os
4
+
5
+
6
+ @click.command
7
+ @click.argument('text')
8
+ @click.argument('output_path')
9
+ @click.option("--file", '-f', is_flag=True, show_default=True, default=False, help="Text is a file")
10
+ @click.option('--language', '-l', default='EN', help='Language, defaults to English', type=click.Choice(['EN', 'ES', 'FR', 'ZH', 'JP', 'KR'], case_sensitive=False))
11
+ @click.option('--speaker', '-spk', default='EN-Default', help='Speaker ID, only for English, leave empty for default, ignored if not English. If English, defaults to "EN-Default"', type=click.Choice(['EN-Default', 'EN-US', 'EN-BR', 'EN_INDIA', 'EN-AU']))
12
+ @click.option('--speed', '-s', default=1.0, help='Speed, defaults to 1.0', type=float)
13
+ @click.option('--device', '-d', default='auto', help='Device, defaults to auto')
14
+ def main(text, file, output_path, language, speaker, speed, device):
15
+ if file:
16
+ if not os.path.exists(text):
17
+ raise FileNotFoundError(f'Trying to load text from file "{text}" due to --file/-f flag, but the file was not found. Remove the --file/-f flag to pass a string.')
18
+ else:
19
+ with open(text) as f:
20
+ text = f.read().strip()
21
+ if text == '':
22
+ raise ValueError('You entered empty text or the file you passed was empty.')
23
+ language = language.upper()
24
+ if language == '': language = 'EN'
25
+ if speaker == '': speaker = None
26
+ if (not language == 'EN') and speaker:
27
+ warnings.warn('You specified a speaker but the language is not English; the speaker option will be ignored.')
28
+ from melo.api import TTS
29
+ model = TTS(language=language, device=device)
30
+ speaker_ids = model.hps.data.spk2id
31
+ if language == 'EN':
32
+ if not speaker: speaker = 'EN-Default'
33
+ spkr = speaker_ids[speaker]
34
+ else:
35
+ spkr = speaker_ids[list(speaker_ids.keys())[0]]
36
+ model.tts_to_file(text, spkr, output_path, speed=speed)
MeloTTS/melo/mel_processing.py ADDED
@@ -0,0 +1,174 @@
1
+ import torch
2
+ import torch.utils.data
3
+ import librosa
4
+ from librosa.filters import mel as librosa_mel_fn
5
+
6
+ MAX_WAV_VALUE = 32768.0
7
+
8
+
9
+ def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
10
+ """
11
+ PARAMS
12
+ ------
13
+ C: compression factor
14
+ """
15
+ return torch.log(torch.clamp(x, min=clip_val) * C)
16
+
17
+
18
+ def dynamic_range_decompression_torch(x, C=1):
19
+ """
20
+ PARAMS
21
+ ------
22
+ C: compression factor used to compress
23
+ """
24
+ return torch.exp(x) / C
25
+
26
+
27
+ def spectral_normalize_torch(magnitudes):
28
+ output = dynamic_range_compression_torch(magnitudes)
29
+ return output
30
+
31
+
32
+ def spectral_de_normalize_torch(magnitudes):
33
+ output = dynamic_range_decompression_torch(magnitudes)
34
+ return output
35
+
36
+
37
+ mel_basis = {}
38
+ hann_window = {}
39
+
40
+
41
+ def spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, center=False):
42
+ if torch.min(y) < -1.1:
43
+ print("min value is ", torch.min(y))
44
+ if torch.max(y) > 1.1:
45
+ print("max value is ", torch.max(y))
46
+
47
+ global hann_window
48
+ dtype_device = str(y.dtype) + "_" + str(y.device)
49
+ wnsize_dtype_device = str(win_size) + "_" + dtype_device
50
+ if wnsize_dtype_device not in hann_window:
51
+ hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(
52
+ dtype=y.dtype, device=y.device
53
+ )
54
+
55
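+ # reflection-pad the signal by (n_fft - hop_size) / 2 on each side, since torch.stft is called with center=False below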
+ y = torch.nn.functional.pad(
56
+ y.unsqueeze(1),
57
+ (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)),
58
+ mode="reflect",
59
+ )
60
+ y = y.squeeze(1)
61
+
62
+ spec = torch.stft(
63
+ y,
64
+ n_fft,
65
+ hop_length=hop_size,
66
+ win_length=win_size,
67
+ window=hann_window[wnsize_dtype_device],
68
+ center=center,
69
+ pad_mode="reflect",
70
+ normalized=False,
71
+ onesided=True,
72
+ return_complex=False,
73
+ )
74
+
75
+ spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
76
+ return spec
77
+
78
+
79
+ def spectrogram_torch_conv(y, n_fft, sampling_rate, hop_size, win_size, center=False):
80
+ global hann_window
81
+ dtype_device = str(y.dtype) + '_' + str(y.device)
82
+ wnsize_dtype_device = str(win_size) + '_' + dtype_device
83
+ if wnsize_dtype_device not in hann_window:
84
+ hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(dtype=y.dtype, device=y.device)
85
+
86
+ y = torch.nn.functional.pad(y.unsqueeze(1), (int((n_fft-hop_size)/2), int((n_fft-hop_size)/2)), mode='reflect')
87
+
88
+ # ******************** original ************************#
89
+ # y = y.squeeze(1)
90
+ # spec1 = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[wnsize_dtype_device],
91
+ # center=center, pad_mode='reflect', normalized=False, onesided=True, return_complex=False)
92
+
93
+ # ******************** ConvSTFT ************************#
94
+ freq_cutoff = n_fft // 2 + 1
95
+ fourier_basis = torch.view_as_real(torch.fft.fft(torch.eye(n_fft)))
96
+ forward_basis = fourier_basis[:freq_cutoff].permute(2, 0, 1).reshape(-1, 1, fourier_basis.shape[1])
97
+ forward_basis = forward_basis * torch.as_tensor(librosa.util.pad_center(torch.hann_window(win_size), size=n_fft)).float()
98
+
99
+ import torch.nn.functional as F
100
+
101
+ # if center:
102
+ # signal = F.pad(y[:, None, None, :], (n_fft // 2, n_fft // 2, 0, 0), mode = 'reflect').squeeze(1)
103
+ assert center is False
104
+
105
+ forward_transform_squared = F.conv1d(y, forward_basis.to(y.device), stride = hop_size)
106
+ spec2 = torch.stack([forward_transform_squared[:, :freq_cutoff, :], forward_transform_squared[:, freq_cutoff:, :]], dim = -1)
107
+
108
+
109
+ # ******************** Verification ************************#
110
+ spec1 = torch.stft(y.squeeze(1), n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[wnsize_dtype_device],
111
+ center=center, pad_mode='reflect', normalized=False, onesided=True, return_complex=False)
112
+ assert torch.allclose(spec1, spec2, atol=1e-4)
113
+
114
+ spec = torch.sqrt(spec2.pow(2).sum(-1) + 1e-6)
115
+ return spec
116
+
117
+
118
+ def spec_to_mel_torch(spec, n_fft, num_mels, sampling_rate, fmin, fmax):
119
+ global mel_basis
120
+ dtype_device = str(spec.dtype) + "_" + str(spec.device)
121
+ fmax_dtype_device = str(fmax) + "_" + dtype_device
122
+ if fmax_dtype_device not in mel_basis:
123
+ mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
124
+ mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(
125
+ dtype=spec.dtype, device=spec.device
126
+ )
127
+ spec = torch.matmul(mel_basis[fmax_dtype_device], spec)
128
+ spec = spectral_normalize_torch(spec)
129
+ return spec
130
+
131
+
132
+ def mel_spectrogram_torch(
133
+ y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False
134
+ ):
135
+ global mel_basis, hann_window
136
+ dtype_device = str(y.dtype) + "_" + str(y.device)
137
+ fmax_dtype_device = str(fmax) + "_" + dtype_device
138
+ wnsize_dtype_device = str(win_size) + "_" + dtype_device
139
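+ # cache the mel filterbank and Hann window per (parameter, dtype, device) combination so they are only built once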
+ if fmax_dtype_device not in mel_basis:
140
+ mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
141
+ mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(
142
+ dtype=y.dtype, device=y.device
143
+ )
144
+ if wnsize_dtype_device not in hann_window:
145
+ hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(
146
+ dtype=y.dtype, device=y.device
147
+ )
148
+
149
+ y = torch.nn.functional.pad(
150
+ y.unsqueeze(1),
151
+ (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)),
152
+ mode="reflect",
153
+ )
154
+ y = y.squeeze(1)
155
+
156
+ spec = torch.stft(
157
+ y,
158
+ n_fft,
159
+ hop_length=hop_size,
160
+ win_length=win_size,
161
+ window=hann_window[wnsize_dtype_device],
162
+ center=center,
163
+ pad_mode="reflect",
164
+ normalized=False,
165
+ onesided=True,
166
+ return_complex=False,
167
+ )
168
+
169
+ spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
170
+
171
+ spec = torch.matmul(mel_basis[fmax_dtype_device], spec)
172
+ spec = spectral_normalize_torch(spec)
173
+
174
+ return spec
MeloTTS/melo/models.py ADDED
@@ -0,0 +1,1030 @@
1
+ import math
2
+ import torch
3
+ from torch import nn
4
+ from torch.nn import functional as F
5
+
6
+ from melo import commons
7
+ from melo import modules
8
+ from melo import attentions
9
+
10
+ from torch.nn import Conv1d, ConvTranspose1d, Conv2d
11
+ from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
12
+
13
+ from melo.commons import init_weights, get_padding
14
+ import melo.monotonic_align as monotonic_align
15
+
16
+
17
+ class DurationDiscriminator(nn.Module): # vits2
18
+ def __init__(
19
+ self, in_channels, filter_channels, kernel_size, p_dropout, gin_channels=0
20
+ ):
21
+ super().__init__()
22
+ self.in_channels = in_channels
23
+ self.filter_channels = filter_channels
24
+ self.kernel_size = kernel_size
25
+ self.p_dropout = p_dropout
26
+ self.gin_channels = gin_channels
27
+
28
+ self.drop = nn.Dropout(p_dropout)
29
+ self.conv_1 = nn.Conv1d(
30
+ in_channels, filter_channels, kernel_size, padding=kernel_size // 2
31
+ )
32
+ self.norm_1 = modules.LayerNorm(filter_channels)
33
+ self.conv_2 = nn.Conv1d(
34
+ filter_channels, filter_channels, kernel_size, padding=kernel_size // 2
35
+ )
36
+ self.norm_2 = modules.LayerNorm(filter_channels)
37
+ self.dur_proj = nn.Conv1d(1, filter_channels, 1)
38
+
39
+ self.pre_out_conv_1 = nn.Conv1d(
40
+ 2 * filter_channels, filter_channels, kernel_size, padding=kernel_size // 2
41
+ )
42
+ self.pre_out_norm_1 = modules.LayerNorm(filter_channels)
43
+ self.pre_out_conv_2 = nn.Conv1d(
44
+ filter_channels, filter_channels, kernel_size, padding=kernel_size // 2
45
+ )
46
+ self.pre_out_norm_2 = modules.LayerNorm(filter_channels)
47
+
48
+ if gin_channels != 0:
49
+ self.cond = nn.Conv1d(gin_channels, in_channels, 1)
50
+
51
+ self.output_layer = nn.Sequential(nn.Linear(filter_channels, 1), nn.Sigmoid())
52
+
53
+ def forward_probability(self, x, x_mask, dur, g=None):
54
+ dur = self.dur_proj(dur)
55
+ x = torch.cat([x, dur], dim=1)
56
+ x = self.pre_out_conv_1(x * x_mask)
57
+ x = torch.relu(x)
58
+ x = self.pre_out_norm_1(x)
59
+ x = self.drop(x)
60
+ x = self.pre_out_conv_2(x * x_mask)
61
+ x = torch.relu(x)
62
+ x = self.pre_out_norm_2(x)
63
+ x = self.drop(x)
64
+ x = x * x_mask
65
+ x = x.transpose(1, 2)
66
+ output_prob = self.output_layer(x)
67
+ return output_prob
68
+
69
+ def forward(self, x, x_mask, dur_r, dur_hat, g=None):
70
+ x = torch.detach(x)
71
+ if g is not None:
72
+ g = torch.detach(g)
73
+ x = x + self.cond(g)
74
+ x = self.conv_1(x * x_mask)
75
+ x = torch.relu(x)
76
+ x = self.norm_1(x)
77
+ x = self.drop(x)
78
+ x = self.conv_2(x * x_mask)
79
+ x = torch.relu(x)
80
+ x = self.norm_2(x)
81
+ x = self.drop(x)
82
+
83
+ output_probs = []
84
+ for dur in [dur_r, dur_hat]:
85
+ output_prob = self.forward_probability(x, x_mask, dur, g)
86
+ output_probs.append(output_prob)
87
+
88
+ return output_probs
89
+
90
+
91
+ class TransformerCouplingBlock(nn.Module):
92
+ def __init__(
93
+ self,
94
+ channels,
95
+ hidden_channels,
96
+ filter_channels,
97
+ n_heads,
98
+ n_layers,
99
+ kernel_size,
100
+ p_dropout,
101
+ n_flows=4,
102
+ gin_channels=0,
103
+ share_parameter=False,
104
+ ):
105
+ super().__init__()
106
+ self.channels = channels
107
+ self.hidden_channels = hidden_channels
108
+ self.kernel_size = kernel_size
109
+ self.n_layers = n_layers
110
+ self.n_flows = n_flows
111
+ self.gin_channels = gin_channels
112
+
113
+ self.flows = nn.ModuleList()
114
+
115
+ self.wn = (
116
+ attentions.FFT(
117
+ hidden_channels,
118
+ filter_channels,
119
+ n_heads,
120
+ n_layers,
121
+ kernel_size,
122
+ p_dropout,
123
+ isflow=True,
124
+ gin_channels=self.gin_channels,
125
+ )
126
+ if share_parameter
127
+ else None
128
+ )
129
+
130
+ for i in range(n_flows):
131
+ self.flows.append(
132
+ modules.TransformerCouplingLayer(
133
+ channels,
134
+ hidden_channels,
135
+ kernel_size,
136
+ n_layers,
137
+ n_heads,
138
+ p_dropout,
139
+ filter_channels,
140
+ mean_only=True,
141
+ wn_sharing_parameter=self.wn,
142
+ gin_channels=self.gin_channels,
143
+ )
144
+ )
145
+ self.flows.append(modules.Flip())
146
+
147
+ def forward(self, x, x_mask, g=None, reverse=False):
148
+ if not reverse:
149
+ for flow in self.flows:
150
+ x, _ = flow(x, x_mask, g=g, reverse=reverse)
151
+ else:
152
+ for flow in reversed(self.flows):
153
+ x = flow(x, x_mask, g=g, reverse=reverse)
154
+ return x
155
+
156
+
157
+ class StochasticDurationPredictor(nn.Module):
158
+ def __init__(
159
+ self,
160
+ in_channels,
161
+ filter_channels,
162
+ kernel_size,
163
+ p_dropout,
164
+ n_flows=4,
165
+ gin_channels=0,
166
+ ):
167
+ super().__init__()
168
+ filter_channels = in_channels  # NOTE: this override should be removed in a future version
169
+ self.in_channels = in_channels
170
+ self.filter_channels = filter_channels
171
+ self.kernel_size = kernel_size
172
+ self.p_dropout = p_dropout
173
+ self.n_flows = n_flows
174
+ self.gin_channels = gin_channels
175
+
176
+ self.log_flow = modules.Log()
177
+ self.flows = nn.ModuleList()
178
+ self.flows.append(modules.ElementwiseAffine(2))
179
+ for i in range(n_flows):
180
+ self.flows.append(
181
+ modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3)
182
+ )
183
+ self.flows.append(modules.Flip())
184
+
185
+ self.post_pre = nn.Conv1d(1, filter_channels, 1)
186
+ self.post_proj = nn.Conv1d(filter_channels, filter_channels, 1)
187
+ self.post_convs = modules.DDSConv(
188
+ filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout
189
+ )
190
+ self.post_flows = nn.ModuleList()
191
+ self.post_flows.append(modules.ElementwiseAffine(2))
192
+ for i in range(4):
193
+ self.post_flows.append(
194
+ modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3)
195
+ )
196
+ self.post_flows.append(modules.Flip())
197
+
198
+ self.pre = nn.Conv1d(in_channels, filter_channels, 1)
199
+ self.proj = nn.Conv1d(filter_channels, filter_channels, 1)
200
+ self.convs = modules.DDSConv(
201
+ filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout
202
+ )
203
+ if gin_channels != 0:
204
+ self.cond = nn.Conv1d(gin_channels, filter_channels, 1)
205
+
206
+ def forward(self, x, x_mask, w=None, g=None, reverse=False, noise_scale=1.0):
207
+ x = torch.detach(x)
208
+ x = self.pre(x)
209
+ if g is not None:
210
+ g = torch.detach(g)
211
+ x = x + self.cond(g)
212
+ x = self.convs(x, x_mask)
213
+ x = self.proj(x) * x_mask
214
+
215
+ if not reverse:
216
+ flows = self.flows
217
+ assert w is not None
218
+
219
+ logdet_tot_q = 0
220
+ h_w = self.post_pre(w)
221
+ h_w = self.post_convs(h_w, x_mask)
222
+ h_w = self.post_proj(h_w) * x_mask
223
+ e_q = (
224
+ torch.randn(w.size(0), 2, w.size(2)).to(device=x.device, dtype=x.dtype)
225
+ * x_mask
226
+ )
227
+ z_q = e_q
228
+ for flow in self.post_flows:
229
+ z_q, logdet_q = flow(z_q, x_mask, g=(x + h_w))
230
+ logdet_tot_q += logdet_q
231
+ z_u, z1 = torch.split(z_q, [1, 1], 1)
232
+ u = torch.sigmoid(z_u) * x_mask
233
+ z0 = (w - u) * x_mask
234
+ logdet_tot_q += torch.sum(
235
+ (F.logsigmoid(z_u) + F.logsigmoid(-z_u)) * x_mask, [1, 2]
236
+ )
237
+ logq = (
238
+ torch.sum(-0.5 * (math.log(2 * math.pi) + (e_q**2)) * x_mask, [1, 2])
239
+ - logdet_tot_q
240
+ )
241
+
242
+ logdet_tot = 0
243
+ z0, logdet = self.log_flow(z0, x_mask)
244
+ logdet_tot += logdet
245
+ z = torch.cat([z0, z1], 1)
246
+ for flow in flows:
247
+ z, logdet = flow(z, x_mask, g=x, reverse=reverse)
248
+ logdet_tot = logdet_tot + logdet
249
+ nll = (
250
+ torch.sum(0.5 * (math.log(2 * math.pi) + (z**2)) * x_mask, [1, 2])
251
+ - logdet_tot
252
+ )
253
+ return nll + logq # [b]
254
+ else:
255
+ flows = list(reversed(self.flows))
256
+ flows = flows[:-2] + [flows[-1]] # remove a useless vflow
257
+ z = (
258
+ torch.randn(x.size(0), 2, x.size(2)).to(device=x.device, dtype=x.dtype)
259
+ * noise_scale
260
+ )
261
+ for flow in flows:
262
+ z = flow(z, x_mask, g=x, reverse=reverse)
263
+ z0, z1 = torch.split(z, [1, 1], 1)
264
+ logw = z0
265
+ return logw
266
+
267
+
268
+ class DurationPredictor(nn.Module):
269
+ def __init__(
270
+ self, in_channels, filter_channels, kernel_size, p_dropout, gin_channels=0
271
+ ):
272
+ super().__init__()
273
+
274
+ self.in_channels = in_channels
275
+ self.filter_channels = filter_channels
276
+ self.kernel_size = kernel_size
277
+ self.p_dropout = p_dropout
278
+ self.gin_channels = gin_channels
279
+
280
+ self.drop = nn.Dropout(p_dropout)
281
+ self.conv_1 = nn.Conv1d(
282
+ in_channels, filter_channels, kernel_size, padding=kernel_size // 2
283
+ )
284
+ self.norm_1 = modules.LayerNorm(filter_channels)
285
+ self.conv_2 = nn.Conv1d(
286
+ filter_channels, filter_channels, kernel_size, padding=kernel_size // 2
287
+ )
288
+ self.norm_2 = modules.LayerNorm(filter_channels)
289
+ self.proj = nn.Conv1d(filter_channels, 1, 1)
290
+
291
+ if gin_channels != 0:
292
+ self.cond = nn.Conv1d(gin_channels, in_channels, 1)
293
+
294
+ def forward(self, x, x_mask, g=None):
295
+ x = torch.detach(x)
296
+ if g is not None:
297
+ g = torch.detach(g)
298
+ x = x + self.cond(g)
299
+ x = self.conv_1(x * x_mask)
300
+ x = torch.relu(x)
301
+ x = self.norm_1(x)
302
+ x = self.drop(x)
303
+ x = self.conv_2(x * x_mask)
304
+ x = torch.relu(x)
305
+ x = self.norm_2(x)
306
+ x = self.drop(x)
307
+ x = self.proj(x * x_mask)
308
+ return x * x_mask
309
+
310
+
311
+ class TextEncoder(nn.Module):
312
+ def __init__(
313
+ self,
314
+ n_vocab,
315
+ out_channels,
316
+ hidden_channels,
317
+ filter_channels,
318
+ n_heads,
319
+ n_layers,
320
+ kernel_size,
321
+ p_dropout,
322
+ gin_channels=0,
323
+ num_languages=None,
324
+ num_tones=None,
325
+ ):
326
+ super().__init__()
327
+ if num_languages is None:
328
+ from text import num_languages
329
+ if num_tones is None:
330
+ from text import num_tones
331
+ self.n_vocab = n_vocab
332
+ self.out_channels = out_channels
333
+ self.hidden_channels = hidden_channels
334
+ self.filter_channels = filter_channels
335
+ self.n_heads = n_heads
336
+ self.n_layers = n_layers
337
+ self.kernel_size = kernel_size
338
+ self.p_dropout = p_dropout
339
+ self.gin_channels = gin_channels
340
+ self.emb = nn.Embedding(n_vocab, hidden_channels)
341
+ nn.init.normal_(self.emb.weight, 0.0, hidden_channels**-0.5)
342
+ self.tone_emb = nn.Embedding(num_tones, hidden_channels)
343
+ nn.init.normal_(self.tone_emb.weight, 0.0, hidden_channels**-0.5)
344
+ self.language_emb = nn.Embedding(num_languages, hidden_channels)
345
+ nn.init.normal_(self.language_emb.weight, 0.0, hidden_channels**-0.5)
346
+ self.bert_proj = nn.Conv1d(1024, hidden_channels, 1)
347
+ self.ja_bert_proj = nn.Conv1d(768, hidden_channels, 1)
348
+
349
+ self.encoder = attentions.Encoder(
350
+ hidden_channels,
351
+ filter_channels,
352
+ n_heads,
353
+ n_layers,
354
+ kernel_size,
355
+ p_dropout,
356
+ gin_channels=self.gin_channels,
357
+ )
358
+ self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
359
+
360
+ def forward(self, x, x_lengths, tone, language, bert, ja_bert, g=None):
361
+ bert_emb = self.bert_proj(bert).transpose(1, 2)
362
+ ja_bert_emb = self.ja_bert_proj(ja_bert).transpose(1, 2)
363
+ x = (
364
+ self.emb(x)
365
+ + self.tone_emb(tone)
366
+ + self.language_emb(language)
367
+ + bert_emb
368
+ + ja_bert_emb
369
+ ) * math.sqrt(
370
+ self.hidden_channels
371
+ ) # [b, t, h]
372
+ x = torch.transpose(x, 1, -1) # [b, h, t]
373
+ x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(
374
+ x.dtype
375
+ )
376
+
377
+ x = self.encoder(x * x_mask, x_mask, g=g)
378
+ stats = self.proj(x) * x_mask
379
+
380
+ m, logs = torch.split(stats, self.out_channels, dim=1)
381
+ return x, m, logs, x_mask
382
+
383
+
384
+ class ResidualCouplingBlock(nn.Module):
385
+ def __init__(
386
+ self,
387
+ channels,
388
+ hidden_channels,
389
+ kernel_size,
390
+ dilation_rate,
391
+ n_layers,
392
+ n_flows=4,
393
+ gin_channels=0,
394
+ ):
395
+ super().__init__()
396
+ self.channels = channels
397
+ self.hidden_channels = hidden_channels
398
+ self.kernel_size = kernel_size
399
+ self.dilation_rate = dilation_rate
400
+ self.n_layers = n_layers
401
+ self.n_flows = n_flows
402
+ self.gin_channels = gin_channels
403
+
404
+ self.flows = nn.ModuleList()
405
+ for i in range(n_flows):
406
+ self.flows.append(
407
+ modules.ResidualCouplingLayer(
408
+ channels,
409
+ hidden_channels,
410
+ kernel_size,
411
+ dilation_rate,
412
+ n_layers,
413
+ gin_channels=gin_channels,
414
+ mean_only=True,
415
+ )
416
+ )
417
+ self.flows.append(modules.Flip())
418
+
419
+ def forward(self, x, x_mask, g=None, reverse=False):
420
+ if not reverse:
421
+ for flow in self.flows:
422
+ x, _ = flow(x, x_mask, g=g, reverse=reverse)
423
+ else:
424
+ for flow in reversed(self.flows):
425
+ x = flow(x, x_mask, g=g, reverse=reverse)
426
+ return x
427
+
428
+
429
+ class PosteriorEncoder(nn.Module):
430
+ def __init__(
431
+ self,
432
+ in_channels,
433
+ out_channels,
434
+ hidden_channels,
435
+ kernel_size,
436
+ dilation_rate,
437
+ n_layers,
438
+ gin_channels=0,
439
+ ):
440
+ super().__init__()
441
+ self.in_channels = in_channels
442
+ self.out_channels = out_channels
443
+ self.hidden_channels = hidden_channels
444
+ self.kernel_size = kernel_size
445
+ self.dilation_rate = dilation_rate
446
+ self.n_layers = n_layers
447
+ self.gin_channels = gin_channels
448
+
449
+ self.pre = nn.Conv1d(in_channels, hidden_channels, 1)
450
+ self.enc = modules.WN(
451
+ hidden_channels,
452
+ kernel_size,
453
+ dilation_rate,
454
+ n_layers,
455
+ gin_channels=gin_channels,
456
+ )
457
+ self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
458
+
459
+ def forward(self, x, x_lengths, g=None, tau=1.0):
460
+ x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(
461
+ x.dtype
462
+ )
463
+ x = self.pre(x) * x_mask
464
+ x = self.enc(x, x_mask, g=g)
465
+ stats = self.proj(x) * x_mask
466
+ m, logs = torch.split(stats, self.out_channels, dim=1)
467
+ z = (m + torch.randn_like(m) * tau * torch.exp(logs)) * x_mask
468
+ return z, m, logs, x_mask
469
+
470
+
471
+ class Generator(torch.nn.Module):
472
+ def __init__(
473
+ self,
474
+ initial_channel,
475
+ resblock,
476
+ resblock_kernel_sizes,
477
+ resblock_dilation_sizes,
478
+ upsample_rates,
479
+ upsample_initial_channel,
480
+ upsample_kernel_sizes,
481
+ gin_channels=0,
482
+ ):
483
+ super(Generator, self).__init__()
484
+ self.num_kernels = len(resblock_kernel_sizes)
485
+ self.num_upsamples = len(upsample_rates)
486
+ self.conv_pre = Conv1d(
487
+ initial_channel, upsample_initial_channel, 7, 1, padding=3
488
+ )
489
+ resblock = modules.ResBlock1 if resblock == "1" else modules.ResBlock2
490
+
491
+ self.ups = nn.ModuleList()
492
+ for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
493
+ self.ups.append(
494
+ weight_norm(
495
+ ConvTranspose1d(
496
+ upsample_initial_channel // (2**i),
497
+ upsample_initial_channel // (2 ** (i + 1)),
498
+ k,
499
+ u,
500
+ padding=(k - u) // 2,
501
+ )
502
+ )
503
+ )
504
+
505
+ self.resblocks = nn.ModuleList()
506
+ for i in range(len(self.ups)):
507
+ ch = upsample_initial_channel // (2 ** (i + 1))
508
+ for j, (k, d) in enumerate(
509
+ zip(resblock_kernel_sizes, resblock_dilation_sizes)
510
+ ):
511
+ self.resblocks.append(resblock(ch, k, d))
512
+
513
+ self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False)
514
+ self.ups.apply(init_weights)
515
+
516
+ if gin_channels != 0:
517
+ self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1)
518
+
519
+ def forward(self, x, g=None):
520
+ x = self.conv_pre(x)
521
+ if g is not None:
522
+ x = x + self.cond(g)
523
+
524
+ for i in range(self.num_upsamples):
525
+ x = F.leaky_relu(x, modules.LRELU_SLOPE)
526
+ x = self.ups[i](x)
527
+ xs = None
528
+ for j in range(self.num_kernels):
529
+ if xs is None:
530
+ xs = self.resblocks[i * self.num_kernels + j](x)
531
+ else:
532
+ xs += self.resblocks[i * self.num_kernels + j](x)
533
+ x = xs / self.num_kernels
534
+ x = F.leaky_relu(x)
535
+ x = self.conv_post(x)
536
+ x = torch.tanh(x)
537
+
538
+ return x
539
+
540
+ def remove_weight_norm(self):
541
+ print("Removing weight norm...")
542
+ for layer in self.ups:
543
+ remove_weight_norm(layer)
544
+ for layer in self.resblocks:
545
+ layer.remove_weight_norm()
546
+
547
+
548
+ class DiscriminatorP(torch.nn.Module):
549
+ def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
550
+ super(DiscriminatorP, self).__init__()
551
+ self.period = period
552
+ self.use_spectral_norm = use_spectral_norm
553
+ norm_f = weight_norm if use_spectral_norm is False else spectral_norm
554
+ self.convs = nn.ModuleList(
555
+ [
556
+ norm_f(
557
+ Conv2d(
558
+ 1,
559
+ 32,
560
+ (kernel_size, 1),
561
+ (stride, 1),
562
+ padding=(get_padding(kernel_size, 1), 0),
563
+ )
564
+ ),
565
+ norm_f(
566
+ Conv2d(
567
+ 32,
568
+ 128,
569
+ (kernel_size, 1),
570
+ (stride, 1),
571
+ padding=(get_padding(kernel_size, 1), 0),
572
+ )
573
+ ),
574
+ norm_f(
575
+ Conv2d(
576
+ 128,
577
+ 512,
578
+ (kernel_size, 1),
579
+ (stride, 1),
580
+ padding=(get_padding(kernel_size, 1), 0),
581
+ )
582
+ ),
583
+ norm_f(
584
+ Conv2d(
585
+ 512,
586
+ 1024,
587
+ (kernel_size, 1),
588
+ (stride, 1),
589
+ padding=(get_padding(kernel_size, 1), 0),
590
+ )
591
+ ),
592
+ norm_f(
593
+ Conv2d(
594
+ 1024,
595
+ 1024,
596
+ (kernel_size, 1),
597
+ 1,
598
+ padding=(get_padding(kernel_size, 1), 0),
599
+ )
600
+ ),
601
+ ]
602
+ )
603
+ self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))
604
+
605
+ def forward(self, x):
606
+ fmap = []
607
+
608
+ # 1d to 2d
609
+ b, c, t = x.shape
610
+ if t % self.period != 0: # pad first
611
+ n_pad = self.period - (t % self.period)
612
+ x = F.pad(x, (0, n_pad), "reflect")
613
+ t = t + n_pad
614
+ x = x.view(b, c, t // self.period, self.period)
615
+
616
+ for layer in self.convs:
617
+ x = layer(x)
618
+ x = F.leaky_relu(x, modules.LRELU_SLOPE)
619
+ fmap.append(x)
620
+ x = self.conv_post(x)
621
+ fmap.append(x)
622
+ x = torch.flatten(x, 1, -1)
623
+
624
+ return x, fmap
625
+
626
+
627
+ class DiscriminatorS(torch.nn.Module):
628
+ def __init__(self, use_spectral_norm=False):
629
+ super(DiscriminatorS, self).__init__()
630
+ norm_f = weight_norm if use_spectral_norm is False else spectral_norm
631
+ self.convs = nn.ModuleList(
632
+ [
633
+ norm_f(Conv1d(1, 16, 15, 1, padding=7)),
634
+ norm_f(Conv1d(16, 64, 41, 4, groups=4, padding=20)),
635
+ norm_f(Conv1d(64, 256, 41, 4, groups=16, padding=20)),
636
+ norm_f(Conv1d(256, 1024, 41, 4, groups=64, padding=20)),
637
+ norm_f(Conv1d(1024, 1024, 41, 4, groups=256, padding=20)),
638
+ norm_f(Conv1d(1024, 1024, 5, 1, padding=2)),
639
+ ]
640
+ )
641
+ self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1))
642
+
643
+ def forward(self, x):
644
+ fmap = []
645
+
646
+ for layer in self.convs:
647
+ x = layer(x)
648
+ x = F.leaky_relu(x, modules.LRELU_SLOPE)
649
+ fmap.append(x)
650
+ x = self.conv_post(x)
651
+ fmap.append(x)
652
+ x = torch.flatten(x, 1, -1)
653
+
654
+ return x, fmap
655
+
656
+
657
+ class MultiPeriodDiscriminator(torch.nn.Module):
658
+ def __init__(self, use_spectral_norm=False):
659
+ super(MultiPeriodDiscriminator, self).__init__()
660
+ periods = [2, 3, 5, 7, 11]
661
+
662
+ discs = [DiscriminatorS(use_spectral_norm=use_spectral_norm)]
663
+ discs = discs + [
664
+ DiscriminatorP(i, use_spectral_norm=use_spectral_norm) for i in periods
665
+ ]
666
+ self.discriminators = nn.ModuleList(discs)
667
+
668
+ def forward(self, y, y_hat):
669
+ y_d_rs = []
670
+ y_d_gs = []
671
+ fmap_rs = []
672
+ fmap_gs = []
673
+ for i, d in enumerate(self.discriminators):
674
+ y_d_r, fmap_r = d(y)
675
+ y_d_g, fmap_g = d(y_hat)
676
+ y_d_rs.append(y_d_r)
677
+ y_d_gs.append(y_d_g)
678
+ fmap_rs.append(fmap_r)
679
+ fmap_gs.append(fmap_g)
680
+
681
+ return y_d_rs, y_d_gs, fmap_rs, fmap_gs
682
+
683
+
684
+ class ReferenceEncoder(nn.Module):
685
+ """
686
+ inputs --- [N, Ty/r, n_mels*r] mels
687
+ outputs --- [N, ref_enc_gru_size]
688
+ """
689
+
690
+ def __init__(self, spec_channels, gin_channels=0, layernorm=False):
691
+ super().__init__()
692
+ self.spec_channels = spec_channels
693
+ ref_enc_filters = [32, 32, 64, 64, 128, 128]
694
+ K = len(ref_enc_filters)
695
+ filters = [1] + ref_enc_filters
696
+ convs = [
697
+ weight_norm(
698
+ nn.Conv2d(
699
+ in_channels=filters[i],
700
+ out_channels=filters[i + 1],
701
+ kernel_size=(3, 3),
702
+ stride=(2, 2),
703
+ padding=(1, 1),
704
+ )
705
+ )
706
+ for i in range(K)
707
+ ]
708
+ self.convs = nn.ModuleList(convs)
709
+ # self.wns = nn.ModuleList([weight_norm(num_features=ref_enc_filters[i]) for i in range(K)]) # noqa: E501
710
+
711
+ out_channels = self.calculate_channels(spec_channels, 3, 2, 1, K)
712
+ self.gru = nn.GRU(
713
+ input_size=ref_enc_filters[-1] * out_channels,
714
+ hidden_size=256 // 2,
715
+ batch_first=True,
716
+ )
717
+ self.proj = nn.Linear(128, gin_channels)
718
+ if layernorm:
719
+ self.layernorm = nn.LayerNorm(self.spec_channels)
720
+ print('[Ref Enc]: using layer norm')
721
+ else:
722
+ self.layernorm = None
723
+
724
+ def forward(self, inputs, mask=None):
725
+ N = inputs.size(0)
726
+
727
+ out = inputs.view(N, 1, -1, self.spec_channels) # [N, 1, Ty, n_freqs]
728
+ if self.layernorm is not None:
729
+ out = self.layernorm(out)
730
+
731
+ for conv in self.convs:
732
+ out = conv(out)
733
+ # out = wn(out)
734
+ out = F.relu(out) # [N, 128, Ty//2^K, n_mels//2^K]
735
+
736
+ out = out.transpose(1, 2) # [N, Ty//2^K, 128, n_mels//2^K]
737
+ T = out.size(1)
738
+ N = out.size(0)
739
+ out = out.contiguous().view(N, T, -1) # [N, Ty//2^K, 128*n_mels//2^K]
740
+
741
+ self.gru.flatten_parameters()
742
+ memory, out = self.gru(out) # out --- [1, N, 128]
743
+
744
+ return self.proj(out.squeeze(0))
745
+
746
+ def calculate_channels(self, L, kernel_size, stride, pad, n_convs):
747
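+ # apply the standard conv output-size formula n_convs times to get the remaining frequency bins after the stride-2 stack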
+ for i in range(n_convs):
748
+ L = (L - kernel_size + 2 * pad) // stride + 1
749
+ return L
750
+
751
+
752
+ class SynthesizerTrn(nn.Module):
753
+ """
754
+ Synthesizer for Training
755
+ """
756
+
757
+ def __init__(
758
+ self,
759
+ n_vocab,
760
+ spec_channels,
761
+ segment_size,
762
+ inter_channels,
763
+ hidden_channels,
764
+ filter_channels,
765
+ n_heads,
766
+ n_layers,
767
+ kernel_size,
768
+ p_dropout,
769
+ resblock,
770
+ resblock_kernel_sizes,
771
+ resblock_dilation_sizes,
772
+ upsample_rates,
773
+ upsample_initial_channel,
774
+ upsample_kernel_sizes,
775
+ n_speakers=256,
776
+ gin_channels=256,
777
+ use_sdp=True,
778
+ n_flow_layer=4,
779
+ n_layers_trans_flow=6,
780
+ flow_share_parameter=False,
781
+ use_transformer_flow=True,
782
+ use_vc=False,
783
+ num_languages=None,
784
+ num_tones=None,
785
+ norm_refenc=False,
786
+ **kwargs
787
+ ):
788
+ super().__init__()
789
+ self.n_vocab = n_vocab
790
+ self.spec_channels = spec_channels
791
+ self.inter_channels = inter_channels
792
+ self.hidden_channels = hidden_channels
793
+ self.filter_channels = filter_channels
794
+ self.n_heads = n_heads
795
+ self.n_layers = n_layers
796
+ self.kernel_size = kernel_size
797
+ self.p_dropout = p_dropout
798
+ self.resblock = resblock
799
+ self.resblock_kernel_sizes = resblock_kernel_sizes
800
+ self.resblock_dilation_sizes = resblock_dilation_sizes
801
+ self.upsample_rates = upsample_rates
802
+ self.upsample_initial_channel = upsample_initial_channel
803
+ self.upsample_kernel_sizes = upsample_kernel_sizes
804
+ self.segment_size = segment_size
805
+ self.n_speakers = n_speakers
806
+ self.gin_channels = gin_channels
807
+ self.n_layers_trans_flow = n_layers_trans_flow
808
+ self.use_spk_conditioned_encoder = kwargs.get(
809
+ "use_spk_conditioned_encoder", True
810
+ )
811
+ self.use_sdp = use_sdp
812
+ self.use_noise_scaled_mas = kwargs.get("use_noise_scaled_mas", False)
813
+ self.mas_noise_scale_initial = kwargs.get("mas_noise_scale_initial", 0.01)
814
+ self.noise_scale_delta = kwargs.get("noise_scale_delta", 2e-6)
815
+ self.current_mas_noise_scale = self.mas_noise_scale_initial
816
+ if self.use_spk_conditioned_encoder and gin_channels > 0:
817
+ self.enc_gin_channels = gin_channels
818
+ else:
819
+ self.enc_gin_channels = 0
820
+ self.enc_p = TextEncoder(
821
+ n_vocab,
822
+ inter_channels,
823
+ hidden_channels,
824
+ filter_channels,
825
+ n_heads,
826
+ n_layers,
827
+ kernel_size,
828
+ p_dropout,
829
+ gin_channels=self.enc_gin_channels,
830
+ num_languages=num_languages,
831
+ num_tones=num_tones,
832
+ )
833
+ self.dec = Generator(
834
+ inter_channels,
835
+ resblock,
836
+ resblock_kernel_sizes,
837
+ resblock_dilation_sizes,
838
+ upsample_rates,
839
+ upsample_initial_channel,
840
+ upsample_kernel_sizes,
841
+ gin_channels=gin_channels,
842
+ )
843
+ self.enc_q = PosteriorEncoder(
844
+ spec_channels,
845
+ inter_channels,
846
+ hidden_channels,
847
+ 5,
848
+ 1,
849
+ 16,
850
+ gin_channels=gin_channels,
851
+ )
852
+ if use_transformer_flow:
853
+ self.flow = TransformerCouplingBlock(
854
+ inter_channels,
855
+ hidden_channels,
856
+ filter_channels,
857
+ n_heads,
858
+ n_layers_trans_flow,
859
+ 5,
860
+ p_dropout,
861
+ n_flow_layer,
862
+ gin_channels=gin_channels,
863
+ share_parameter=flow_share_parameter,
864
+ )
865
+ else:
866
+ self.flow = ResidualCouplingBlock(
867
+ inter_channels,
868
+ hidden_channels,
869
+ 5,
870
+ 1,
871
+ n_flow_layer,
872
+ gin_channels=gin_channels,
873
+ )
874
+ self.sdp = StochasticDurationPredictor(
875
+ hidden_channels, 192, 3, 0.5, 4, gin_channels=gin_channels
876
+ )
877
+ self.dp = DurationPredictor(
878
+ hidden_channels, 256, 3, 0.5, gin_channels=gin_channels
879
+ )
880
+
881
+ if n_speakers > 0:
882
+ self.emb_g = nn.Embedding(n_speakers, gin_channels)
883
+ else:
884
+ self.ref_enc = ReferenceEncoder(spec_channels, gin_channels, layernorm=norm_refenc)
885
+ self.use_vc = use_vc
886
+
887
+
888
+ def forward(self, x, x_lengths, y, y_lengths, sid, tone, language, bert, ja_bert):
889
+ if self.n_speakers > 0:
890
+ g = self.emb_g(sid).unsqueeze(-1) # [b, h, 1]
891
+ else:
892
+ g = self.ref_enc(y.transpose(1, 2)).unsqueeze(-1)
893
+ if self.use_vc:
894
+ g_p = None
895
+ else:
896
+ g_p = g
897
+ x, m_p, logs_p, x_mask = self.enc_p(
898
+ x, x_lengths, tone, language, bert, ja_bert, g=g_p
899
+ )
900
+ z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=g)
901
+ z_p = self.flow(z, y_mask, g=g)
902
+
903
+ with torch.no_grad():
904
+ # negative cross-entropy
905
+ s_p_sq_r = torch.exp(-2 * logs_p) # [b, d, t]
906
+ neg_cent1 = torch.sum(
907
+ -0.5 * math.log(2 * math.pi) - logs_p, [1], keepdim=True
908
+ ) # [b, 1, t_s]
909
+ neg_cent2 = torch.matmul(
910
+ -0.5 * (z_p**2).transpose(1, 2), s_p_sq_r
911
+ ) # [b, t_t, d] x [b, d, t_s] = [b, t_t, t_s]
912
+ neg_cent3 = torch.matmul(
913
+ z_p.transpose(1, 2), (m_p * s_p_sq_r)
914
+ ) # [b, t_t, d] x [b, d, t_s] = [b, t_t, t_s]
915
+ neg_cent4 = torch.sum(
916
+ -0.5 * (m_p**2) * s_p_sq_r, [1], keepdim=True
917
+ ) # [b, 1, t_s]
918
+ neg_cent = neg_cent1 + neg_cent2 + neg_cent3 + neg_cent4
919
+ if self.use_noise_scaled_mas:
920
+ epsilon = (
921
+ torch.std(neg_cent)
922
+ * torch.randn_like(neg_cent)
923
+ * self.current_mas_noise_scale
924
+ )
925
+ neg_cent = neg_cent + epsilon
926
+
927
+ attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1)
928
+ attn = (
929
+ monotonic_align.maximum_path(neg_cent, attn_mask.squeeze(1))
930
+ .unsqueeze(1)
931
+ .detach()
932
+ )
933
+
934
+ w = attn.sum(2)
935
+
936
+ l_length_sdp = self.sdp(x, x_mask, w, g=g)
937
+ l_length_sdp = l_length_sdp / torch.sum(x_mask)
938
+
939
+ logw_ = torch.log(w + 1e-6) * x_mask
940
+ logw = self.dp(x, x_mask, g=g)
941
+ l_length_dp = torch.sum((logw - logw_) ** 2, [1, 2]) / torch.sum(
942
+ x_mask
943
+ ) # for averaging
944
+
945
+ l_length = l_length_dp + l_length_sdp
946
+
947
+ # expand prior
948
+ m_p = torch.matmul(attn.squeeze(1), m_p.transpose(1, 2)).transpose(1, 2)
949
+ logs_p = torch.matmul(attn.squeeze(1), logs_p.transpose(1, 2)).transpose(1, 2)
950
+
951
+ z_slice, ids_slice = commons.rand_slice_segments(
952
+ z, y_lengths, self.segment_size
953
+ )
954
+ o = self.dec(z_slice, g=g)
955
+ return (
956
+ o,
957
+ l_length,
958
+ attn,
959
+ ids_slice,
960
+ x_mask,
961
+ y_mask,
962
+ (z, z_p, m_p, logs_p, m_q, logs_q),
963
+ (x, logw, logw_),
964
+ )
965
+
966
+ def infer(
967
+ self,
968
+ x,
969
+ x_lengths,
970
+ sid,
971
+ tone,
972
+ language,
973
+ bert,
974
+ ja_bert,
975
+ noise_scale=0.667,
976
+ length_scale=1,
977
+ noise_scale_w=0.8,
978
+ max_len=None,
979
+ sdp_ratio=0,
980
+ y=None,
981
+ g=None,
982
+ ):
983
+ # x, m_p, logs_p, x_mask = self.enc_p(x, x_lengths, tone, language, bert)
984
+ # g = self.gst(y)
985
+ if g is None:
986
+ if self.n_speakers > 0:
987
+ g = self.emb_g(sid).unsqueeze(-1) # [b, h, 1]
988
+ else:
989
+ g = self.ref_enc(y.transpose(1, 2)).unsqueeze(-1)
990
+ if self.use_vc:
991
+ g_p = None
992
+ else:
993
+ g_p = g
994
+ x, m_p, logs_p, x_mask = self.enc_p(
995
+ x, x_lengths, tone, language, bert, ja_bert, g=g_p
996
+ )
997
+ logw = self.sdp(x, x_mask, g=g, reverse=True, noise_scale=noise_scale_w) * (
998
+ sdp_ratio
999
+ ) + self.dp(x, x_mask, g=g) * (1 - sdp_ratio)
1000
+ w = torch.exp(logw) * x_mask * length_scale
1001
+
1002
+ w_ceil = torch.ceil(w)
1003
+ y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long()
1004
+ y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, None), 1).to(
1005
+ x_mask.dtype
1006
+ )
1007
+ attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1)
1008
+ attn = commons.generate_path(w_ceil, attn_mask)
1009
+
1010
+ m_p = torch.matmul(attn.squeeze(1), m_p.transpose(1, 2)).transpose(
1011
+ 1, 2
1012
+ ) # [b, t', t], [b, t, d] -> [b, d, t']
1013
+ logs_p = torch.matmul(attn.squeeze(1), logs_p.transpose(1, 2)).transpose(
1014
+ 1, 2
1015
+ ) # [b, t', t], [b, t, d] -> [b, d, t']
1016
+
1017
+ z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale
1018
+ z = self.flow(z_p, y_mask, g=g, reverse=True)
1019
+ o = self.dec((z * y_mask)[:, :, :max_len], g=g)
1020
+ # print('max/min of o:', o.max(), o.min())
1021
+ return o, attn, y_mask, (z, z_p, m_p, logs_p)
1022
+
1023
+ def voice_conversion(self, y, y_lengths, sid_src, sid_tgt, tau=1.0):
1024
+ g_src = sid_src
1025
+ g_tgt = sid_tgt
1026
+ z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=g_src, tau=tau)
1027
+ z_p = self.flow(z, y_mask, g=g_src)
1028
+ z_hat = self.flow(z_p, y_mask, g=g_tgt, reverse=True)
1029
+ o_hat = self.dec(z_hat * y_mask, g=g_tgt)
1030
+ return o_hat, y_mask, (z, z_p, z_hat)
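
Note: the forward / infer / voice_conversion entry points above are normally driven through the melo.api.TTS wrapper that is part of this upload rather than called directly. A minimal usage sketch, assuming a downloaded English checkpoint; the speaker key "EN-Default" is an assumption and depends on the spk2id map shipped with the model:

from melo.api import TTS

device = "cpu"  # or "cuda:0" if available
model = TTS(language="EN", device=device)      # loads config and checkpoint
speaker_ids = model.hps.data.spk2id            # speaker-name -> id map from the config
model.tts_to_file(
    "Text-to-speech has seen rapid progress recently.",
    speaker_ids["EN-Default"],                 # assumed speaker key
    "en_default.wav",
    speed=1.0,
)
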
MeloTTS/melo/modules.py ADDED
@@ -0,0 +1,598 @@
1
+ import math
2
+ import torch
3
+ from torch import nn
4
+ from torch.nn import functional as F
5
+
6
+ from torch.nn import Conv1d
7
+ from torch.nn.utils import weight_norm, remove_weight_norm
8
+
9
+ from . import commons
10
+ from .commons import init_weights, get_padding
11
+ from .transforms import piecewise_rational_quadratic_transform
12
+ from .attentions import Encoder
13
+
14
+ LRELU_SLOPE = 0.1
15
+
16
+
17
+ class LayerNorm(nn.Module):
18
+ def __init__(self, channels, eps=1e-5):
19
+ super().__init__()
20
+ self.channels = channels
21
+ self.eps = eps
22
+
23
+ self.gamma = nn.Parameter(torch.ones(channels))
24
+ self.beta = nn.Parameter(torch.zeros(channels))
25
+
26
+ def forward(self, x):
27
+ x = x.transpose(1, -1)
28
+ x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps)
29
+ return x.transpose(1, -1)
30
+
31
+
32
+ class ConvReluNorm(nn.Module):
33
+ def __init__(
34
+ self,
35
+ in_channels,
36
+ hidden_channels,
37
+ out_channels,
38
+ kernel_size,
39
+ n_layers,
40
+ p_dropout,
41
+ ):
42
+ super().__init__()
43
+ self.in_channels = in_channels
44
+ self.hidden_channels = hidden_channels
45
+ self.out_channels = out_channels
46
+ self.kernel_size = kernel_size
47
+ self.n_layers = n_layers
48
+ self.p_dropout = p_dropout
49
+ assert n_layers > 1, "Number of layers should be larger than 1."
50
+
51
+ self.conv_layers = nn.ModuleList()
52
+ self.norm_layers = nn.ModuleList()
53
+ self.conv_layers.append(
54
+ nn.Conv1d(
55
+ in_channels, hidden_channels, kernel_size, padding=kernel_size // 2
56
+ )
57
+ )
58
+ self.norm_layers.append(LayerNorm(hidden_channels))
59
+ self.relu_drop = nn.Sequential(nn.ReLU(), nn.Dropout(p_dropout))
60
+ for _ in range(n_layers - 1):
61
+ self.conv_layers.append(
62
+ nn.Conv1d(
63
+ hidden_channels,
64
+ hidden_channels,
65
+ kernel_size,
66
+ padding=kernel_size // 2,
67
+ )
68
+ )
69
+ self.norm_layers.append(LayerNorm(hidden_channels))
70
+ self.proj = nn.Conv1d(hidden_channels, out_channels, 1)
71
+ self.proj.weight.data.zero_()
72
+ self.proj.bias.data.zero_()
73
+
74
+ def forward(self, x, x_mask):
75
+ x_org = x
76
+ for i in range(self.n_layers):
77
+ x = self.conv_layers[i](x * x_mask)
78
+ x = self.norm_layers[i](x)
79
+ x = self.relu_drop(x)
80
+ x = x_org + self.proj(x)
81
+ return x * x_mask
82
+
83
+
84
+ class DDSConv(nn.Module):
85
+ """
86
+ Dilated and Depth-Separable Convolution
87
+ """
88
+
89
+ def __init__(self, channels, kernel_size, n_layers, p_dropout=0.0):
90
+ super().__init__()
91
+ self.channels = channels
92
+ self.kernel_size = kernel_size
93
+ self.n_layers = n_layers
94
+ self.p_dropout = p_dropout
95
+
96
+ self.drop = nn.Dropout(p_dropout)
97
+ self.convs_sep = nn.ModuleList()
98
+ self.convs_1x1 = nn.ModuleList()
99
+ self.norms_1 = nn.ModuleList()
100
+ self.norms_2 = nn.ModuleList()
101
+ for i in range(n_layers):
102
+ dilation = kernel_size**i
103
+ padding = (kernel_size * dilation - dilation) // 2
104
+ self.convs_sep.append(
105
+ nn.Conv1d(
106
+ channels,
107
+ channels,
108
+ kernel_size,
109
+ groups=channels,
110
+ dilation=dilation,
111
+ padding=padding,
112
+ )
113
+ )
114
+ self.convs_1x1.append(nn.Conv1d(channels, channels, 1))
115
+ self.norms_1.append(LayerNorm(channels))
116
+ self.norms_2.append(LayerNorm(channels))
117
+
118
+ def forward(self, x, x_mask, g=None):
119
+ if g is not None:
120
+ x = x + g
121
+ for i in range(self.n_layers):
122
+ y = self.convs_sep[i](x * x_mask)
123
+ y = self.norms_1[i](y)
124
+ y = F.gelu(y)
125
+ y = self.convs_1x1[i](y)
126
+ y = self.norms_2[i](y)
127
+ y = F.gelu(y)
128
+ y = self.drop(y)
129
+ x = x + y
130
+ return x * x_mask
131
+
132
+
133
+ class WN(torch.nn.Module):
134
+ def __init__(
135
+ self,
136
+ hidden_channels,
137
+ kernel_size,
138
+ dilation_rate,
139
+ n_layers,
140
+ gin_channels=0,
141
+ p_dropout=0,
142
+ ):
143
+ super(WN, self).__init__()
144
+ assert kernel_size % 2 == 1
145
+ self.hidden_channels = hidden_channels
146
+ self.kernel_size = (kernel_size,)
147
+ self.dilation_rate = dilation_rate
148
+ self.n_layers = n_layers
149
+ self.gin_channels = gin_channels
150
+ self.p_dropout = p_dropout
151
+
152
+ self.in_layers = torch.nn.ModuleList()
153
+ self.res_skip_layers = torch.nn.ModuleList()
154
+ self.drop = nn.Dropout(p_dropout)
155
+
156
+ if gin_channels != 0:
157
+ cond_layer = torch.nn.Conv1d(
158
+ gin_channels, 2 * hidden_channels * n_layers, 1
159
+ )
160
+ self.cond_layer = torch.nn.utils.weight_norm(cond_layer, name="weight")
161
+
162
+ for i in range(n_layers):
163
+ dilation = dilation_rate**i
164
+ padding = int((kernel_size * dilation - dilation) / 2)
165
+ in_layer = torch.nn.Conv1d(
166
+ hidden_channels,
167
+ 2 * hidden_channels,
168
+ kernel_size,
169
+ dilation=dilation,
170
+ padding=padding,
171
+ )
172
+ in_layer = torch.nn.utils.weight_norm(in_layer, name="weight")
173
+ self.in_layers.append(in_layer)
174
+
175
+ # last one is not necessary
176
+ if i < n_layers - 1:
177
+ res_skip_channels = 2 * hidden_channels
178
+ else:
179
+ res_skip_channels = hidden_channels
180
+
181
+ res_skip_layer = torch.nn.Conv1d(hidden_channels, res_skip_channels, 1)
182
+ res_skip_layer = torch.nn.utils.weight_norm(res_skip_layer, name="weight")
183
+ self.res_skip_layers.append(res_skip_layer)
184
+
185
+ def forward(self, x, x_mask, g=None, **kwargs):
186
+ output = torch.zeros_like(x)
187
+ n_channels_tensor = torch.IntTensor([self.hidden_channels])
188
+
189
+ if g is not None:
190
+ g = self.cond_layer(g)
191
+
192
+ for i in range(self.n_layers):
193
+ x_in = self.in_layers[i](x)
194
+ if g is not None:
195
+ cond_offset = i * 2 * self.hidden_channels
196
+ g_l = g[:, cond_offset : cond_offset + 2 * self.hidden_channels, :]
197
+ else:
198
+ g_l = torch.zeros_like(x_in)
199
+
200
+ acts = commons.fused_add_tanh_sigmoid_multiply(x_in, g_l, n_channels_tensor)
201
+ acts = self.drop(acts)
202
+
203
+ res_skip_acts = self.res_skip_layers[i](acts)
204
+ if i < self.n_layers - 1:
205
+ res_acts = res_skip_acts[:, : self.hidden_channels, :]
206
+ x = (x + res_acts) * x_mask
207
+ output = output + res_skip_acts[:, self.hidden_channels :, :]
208
+ else:
209
+ output = output + res_skip_acts
210
+ return output * x_mask
211
+
212
+ def remove_weight_norm(self):
213
+ if self.gin_channels != 0:
214
+ torch.nn.utils.remove_weight_norm(self.cond_layer)
215
+ for l in self.in_layers:
216
+ torch.nn.utils.remove_weight_norm(l)
217
+ for l in self.res_skip_layers:
218
+ torch.nn.utils.remove_weight_norm(l)
219
+
220
+
221
+ class ResBlock1(torch.nn.Module):
222
+ def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
223
+ super(ResBlock1, self).__init__()
224
+ self.convs1 = nn.ModuleList(
225
+ [
226
+ weight_norm(
227
+ Conv1d(
228
+ channels,
229
+ channels,
230
+ kernel_size,
231
+ 1,
232
+ dilation=dilation[0],
233
+ padding=get_padding(kernel_size, dilation[0]),
234
+ )
235
+ ),
236
+ weight_norm(
237
+ Conv1d(
238
+ channels,
239
+ channels,
240
+ kernel_size,
241
+ 1,
242
+ dilation=dilation[1],
243
+ padding=get_padding(kernel_size, dilation[1]),
244
+ )
245
+ ),
246
+ weight_norm(
247
+ Conv1d(
248
+ channels,
249
+ channels,
250
+ kernel_size,
251
+ 1,
252
+ dilation=dilation[2],
253
+ padding=get_padding(kernel_size, dilation[2]),
254
+ )
255
+ ),
256
+ ]
257
+ )
258
+ self.convs1.apply(init_weights)
259
+
260
+ self.convs2 = nn.ModuleList(
261
+ [
262
+ weight_norm(
263
+ Conv1d(
264
+ channels,
265
+ channels,
266
+ kernel_size,
267
+ 1,
268
+ dilation=1,
269
+ padding=get_padding(kernel_size, 1),
270
+ )
271
+ ),
272
+ weight_norm(
273
+ Conv1d(
274
+ channels,
275
+ channels,
276
+ kernel_size,
277
+ 1,
278
+ dilation=1,
279
+ padding=get_padding(kernel_size, 1),
280
+ )
281
+ ),
282
+ weight_norm(
283
+ Conv1d(
284
+ channels,
285
+ channels,
286
+ kernel_size,
287
+ 1,
288
+ dilation=1,
289
+ padding=get_padding(kernel_size, 1),
290
+ )
291
+ ),
292
+ ]
293
+ )
294
+ self.convs2.apply(init_weights)
295
+
296
+ def forward(self, x, x_mask=None):
297
+ for c1, c2 in zip(self.convs1, self.convs2):
298
+ xt = F.leaky_relu(x, LRELU_SLOPE)
299
+ if x_mask is not None:
300
+ xt = xt * x_mask
301
+ xt = c1(xt)
302
+ xt = F.leaky_relu(xt, LRELU_SLOPE)
303
+ if x_mask is not None:
304
+ xt = xt * x_mask
305
+ xt = c2(xt)
306
+ x = xt + x
307
+ if x_mask is not None:
308
+ x = x * x_mask
309
+ return x
310
+
311
+ def remove_weight_norm(self):
312
+ for l in self.convs1:
313
+ remove_weight_norm(l)
314
+ for l in self.convs2:
315
+ remove_weight_norm(l)
316
+
317
+
318
+ class ResBlock2(torch.nn.Module):
319
+ def __init__(self, channels, kernel_size=3, dilation=(1, 3)):
320
+ super(ResBlock2, self).__init__()
321
+ self.convs = nn.ModuleList(
322
+ [
323
+ weight_norm(
324
+ Conv1d(
325
+ channels,
326
+ channels,
327
+ kernel_size,
328
+ 1,
329
+ dilation=dilation[0],
330
+ padding=get_padding(kernel_size, dilation[0]),
331
+ )
332
+ ),
333
+ weight_norm(
334
+ Conv1d(
335
+ channels,
336
+ channels,
337
+ kernel_size,
338
+ 1,
339
+ dilation=dilation[1],
340
+ padding=get_padding(kernel_size, dilation[1]),
341
+ )
342
+ ),
343
+ ]
344
+ )
345
+ self.convs.apply(init_weights)
346
+
347
+ def forward(self, x, x_mask=None):
348
+ for c in self.convs:
349
+ xt = F.leaky_relu(x, LRELU_SLOPE)
350
+ if x_mask is not None:
351
+ xt = xt * x_mask
352
+ xt = c(xt)
353
+ x = xt + x
354
+ if x_mask is not None:
355
+ x = x * x_mask
356
+ return x
357
+
358
+ def remove_weight_norm(self):
359
+ for l in self.convs:
360
+ remove_weight_norm(l)
361
+
362
+
363
+ class Log(nn.Module):
364
+ def forward(self, x, x_mask, reverse=False, **kwargs):
365
+ if not reverse:
366
+ y = torch.log(torch.clamp_min(x, 1e-5)) * x_mask
367
+ logdet = torch.sum(-y, [1, 2])
368
+ return y, logdet
369
+ else:
370
+ x = torch.exp(x) * x_mask
371
+ return x
372
+
373
+
374
+ class Flip(nn.Module):
375
+ def forward(self, x, *args, reverse=False, **kwargs):
376
+ x = torch.flip(x, [1])
377
+ if not reverse:
378
+ logdet = torch.zeros(x.size(0)).to(dtype=x.dtype, device=x.device)
379
+ return x, logdet
380
+ else:
381
+ return x
382
+
383
+
384
+ class ElementwiseAffine(nn.Module):
385
+ def __init__(self, channels):
386
+ super().__init__()
387
+ self.channels = channels
388
+ self.m = nn.Parameter(torch.zeros(channels, 1))
389
+ self.logs = nn.Parameter(torch.zeros(channels, 1))
390
+
391
+ def forward(self, x, x_mask, reverse=False, **kwargs):
392
+ if not reverse:
393
+ y = self.m + torch.exp(self.logs) * x
394
+ y = y * x_mask
395
+ logdet = torch.sum(self.logs * x_mask, [1, 2])
396
+ return y, logdet
397
+ else:
398
+ x = (x - self.m) * torch.exp(-self.logs) * x_mask
399
+ return x
400
+
401
+
402
+ class ResidualCouplingLayer(nn.Module):
403
+ def __init__(
404
+ self,
405
+ channels,
406
+ hidden_channels,
407
+ kernel_size,
408
+ dilation_rate,
409
+ n_layers,
410
+ p_dropout=0,
411
+ gin_channels=0,
412
+ mean_only=False,
413
+ ):
414
+ assert channels % 2 == 0, "channels should be divisible by 2"
415
+ super().__init__()
416
+ self.channels = channels
417
+ self.hidden_channels = hidden_channels
418
+ self.kernel_size = kernel_size
419
+ self.dilation_rate = dilation_rate
420
+ self.n_layers = n_layers
421
+ self.half_channels = channels // 2
422
+ self.mean_only = mean_only
423
+
424
+ self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1)
425
+ self.enc = WN(
426
+ hidden_channels,
427
+ kernel_size,
428
+ dilation_rate,
429
+ n_layers,
430
+ p_dropout=p_dropout,
431
+ gin_channels=gin_channels,
432
+ )
433
+ self.post = nn.Conv1d(hidden_channels, self.half_channels * (2 - mean_only), 1)
434
+ self.post.weight.data.zero_()
435
+ self.post.bias.data.zero_()
436
+
437
+ def forward(self, x, x_mask, g=None, reverse=False):
438
+ x0, x1 = torch.split(x, [self.half_channels] * 2, 1)
439
+ h = self.pre(x0) * x_mask
440
+ h = self.enc(h, x_mask, g=g)
441
+ stats = self.post(h) * x_mask
442
+ if not self.mean_only:
443
+ m, logs = torch.split(stats, [self.half_channels] * 2, 1)
444
+ else:
445
+ m = stats
446
+ logs = torch.zeros_like(m)
447
+
448
+ if not reverse:
449
+ x1 = m + x1 * torch.exp(logs) * x_mask
450
+ x = torch.cat([x0, x1], 1)
451
+ logdet = torch.sum(logs, [1, 2])
452
+ return x, logdet
453
+ else:
454
+ x1 = (x1 - m) * torch.exp(-logs) * x_mask
455
+ x = torch.cat([x0, x1], 1)
456
+ return x
457
+
458
+
459
+ class ConvFlow(nn.Module):
460
+ def __init__(
461
+ self,
462
+ in_channels,
463
+ filter_channels,
464
+ kernel_size,
465
+ n_layers,
466
+ num_bins=10,
467
+ tail_bound=5.0,
468
+ ):
469
+ super().__init__()
470
+ self.in_channels = in_channels
471
+ self.filter_channels = filter_channels
472
+ self.kernel_size = kernel_size
473
+ self.n_layers = n_layers
474
+ self.num_bins = num_bins
475
+ self.tail_bound = tail_bound
476
+ self.half_channels = in_channels // 2
477
+
478
+ self.pre = nn.Conv1d(self.half_channels, filter_channels, 1)
479
+ self.convs = DDSConv(filter_channels, kernel_size, n_layers, p_dropout=0.0)
480
+ self.proj = nn.Conv1d(
481
+ filter_channels, self.half_channels * (num_bins * 3 - 1), 1
482
+ )
483
+ self.proj.weight.data.zero_()
484
+ self.proj.bias.data.zero_()
485
+
486
+ def forward(self, x, x_mask, g=None, reverse=False):
487
+ x0, x1 = torch.split(x, [self.half_channels] * 2, 1)
488
+ h = self.pre(x0)
489
+ h = self.convs(h, x_mask, g=g)
490
+ h = self.proj(h) * x_mask
491
+
492
+ b, c, t = x0.shape
493
+ h = h.reshape(b, c, -1, t).permute(0, 1, 3, 2) # [b, cx?, t] -> [b, c, t, ?]
494
+
495
+ unnormalized_widths = h[..., : self.num_bins] / math.sqrt(self.filter_channels)
496
+ unnormalized_heights = h[..., self.num_bins : 2 * self.num_bins] / math.sqrt(
497
+ self.filter_channels
498
+ )
499
+ unnormalized_derivatives = h[..., 2 * self.num_bins :]
500
+
501
+ x1, logabsdet = piecewise_rational_quadratic_transform(
502
+ x1,
503
+ unnormalized_widths,
504
+ unnormalized_heights,
505
+ unnormalized_derivatives,
506
+ inverse=reverse,
507
+ tails="linear",
508
+ tail_bound=self.tail_bound,
509
+ )
510
+
511
+ x = torch.cat([x0, x1], 1) * x_mask
512
+ logdet = torch.sum(logabsdet * x_mask, [1, 2])
513
+ if not reverse:
514
+ return x, logdet
515
+ else:
516
+ return x
517
+
518
+
519
+ class TransformerCouplingLayer(nn.Module):
520
+ def __init__(
521
+ self,
522
+ channels,
523
+ hidden_channels,
524
+ kernel_size,
525
+ n_layers,
526
+ n_heads,
527
+ p_dropout=0,
528
+ filter_channels=0,
529
+ mean_only=False,
530
+ wn_sharing_parameter=None,
531
+ gin_channels=0,
532
+ ):
533
+ assert n_layers == 3, n_layers
534
+ assert channels % 2 == 0, "channels should be divisible by 2"
535
+ super().__init__()
536
+ self.channels = channels
537
+ self.hidden_channels = hidden_channels
538
+ self.kernel_size = kernel_size
539
+ self.n_layers = n_layers
540
+ self.half_channels = channels // 2
541
+ self.mean_only = mean_only
542
+
543
+ self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1)
544
+ self.enc = (
545
+ Encoder(
546
+ hidden_channels,
547
+ filter_channels,
548
+ n_heads,
549
+ n_layers,
550
+ kernel_size,
551
+ p_dropout,
552
+ isflow=True,
553
+ gin_channels=gin_channels,
554
+ )
555
+ if wn_sharing_parameter is None
556
+ else wn_sharing_parameter
557
+ )
558
+ self.post = nn.Conv1d(hidden_channels, self.half_channels * (2 - mean_only), 1)
559
+ self.post.weight.data.zero_()
560
+ self.post.bias.data.zero_()
561
+
562
+ def forward(self, x, x_mask, g=None, reverse=False):
563
+ x0, x1 = torch.split(x, [self.half_channels] * 2, 1)
564
+ h = self.pre(x0) * x_mask
565
+ h = self.enc(h, x_mask, g=g)
566
+ stats = self.post(h) * x_mask
567
+ if not self.mean_only:
568
+ m, logs = torch.split(stats, [self.half_channels] * 2, 1)
569
+ else:
570
+ m = stats
571
+ logs = torch.zeros_like(m)
572
+
573
+ if not reverse:
574
+ x1 = m + x1 * torch.exp(logs) * x_mask
575
+ x = torch.cat([x0, x1], 1)
576
+ logdet = torch.sum(logs, [1, 2])
577
+ return x, logdet
578
+ else:
579
+ x1 = (x1 - m) * torch.exp(-logs) * x_mask
580
+ x = torch.cat([x0, x1], 1)
581
+ return x
582
+
583
+ x1, logabsdet = piecewise_rational_quadratic_transform(
584
+ x1,
585
+ unnormalized_widths,
586
+ unnormalized_heights,
587
+ unnormalized_derivatives,
588
+ inverse=reverse,
589
+ tails="linear",
590
+ tail_bound=self.tail_bound,
591
+ )
592
+
593
+ x = torch.cat([x0, x1], 1) * x_mask
594
+ logdet = torch.sum(logabsdet * x_mask, [1, 2])
595
+ if not reverse:
596
+ return x, logdet
597
+ else:
598
+ return x
MeloTTS/melo/monotonic_align/__init__.py ADDED
@@ -0,0 +1,16 @@
1
+ from numpy import zeros, int32, float32
2
+ from torch import from_numpy
3
+
4
+ from .core import maximum_path_jit
5
+
6
+
7
+ def maximum_path(neg_cent, mask):
8
+ device = neg_cent.device
9
+ dtype = neg_cent.dtype
10
+ neg_cent = neg_cent.data.cpu().numpy().astype(float32)
11
+ path = zeros(neg_cent.shape, dtype=int32)
12
+
13
+ t_t_max = mask.sum(1)[:, 0].data.cpu().numpy().astype(int32)
14
+ t_s_max = mask.sum(2)[:, 0].data.cpu().numpy().astype(int32)
15
+ maximum_path_jit(path, neg_cent, t_t_max, t_s_max)
16
+ return from_numpy(path).to(device=device, dtype=dtype)
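
For reference, maximum_path computes a hard monotonic alignment: given a [batch, frames, tokens] score tensor and a 0/1 mask of the same shape, it returns a 0/1 path that assigns each frame exactly one token index. A small illustrative sketch with toy shapes, assuming the melo package is importable:

import torch
from melo.monotonic_align import maximum_path

neg_cent = torch.randn(1, 6, 4)    # alignment scores: 1 utterance, 6 frames, 4 tokens
mask = torch.ones(1, 6, 4)         # everything valid in this toy case
path = maximum_path(neg_cent, mask)
print(path.shape)                  # torch.Size([1, 6, 4])
print(path.sum())                  # tensor(6.) -- exactly one token per frame
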
MeloTTS/melo/monotonic_align/core.py ADDED
@@ -0,0 +1,46 @@
1
+ import numba
2
+
3
+
4
+ @numba.jit(
5
+ numba.void(
6
+ numba.int32[:, :, ::1],
7
+ numba.float32[:, :, ::1],
8
+ numba.int32[::1],
9
+ numba.int32[::1],
10
+ ),
11
+ nopython=True,
12
+ nogil=True,
13
+ )
14
+ def maximum_path_jit(paths, values, t_ys, t_xs):
15
+ b = paths.shape[0]
16
+ max_neg_val = -1e9
17
+ for i in range(int(b)):
18
+ path = paths[i]
19
+ value = values[i]
20
+ t_y = t_ys[i]
21
+ t_x = t_xs[i]
22
+
23
+ v_prev = v_cur = 0.0
24
+ index = t_x - 1
25
+
26
+ for y in range(t_y):
27
+ for x in range(max(0, t_x + y - t_y), min(t_x, y + 1)):
28
+ if x == y:
29
+ v_cur = max_neg_val
30
+ else:
31
+ v_cur = value[y - 1, x]
32
+ if x == 0:
33
+ if y == 0:
34
+ v_prev = 0.0
35
+ else:
36
+ v_prev = max_neg_val
37
+ else:
38
+ v_prev = value[y - 1, x - 1]
39
+ value[y, x] += max(v_prev, v_cur)
40
+
41
+ for y in range(t_y - 1, -1, -1):
42
+ path[y, index] = 1
43
+ if index != 0 and (
44
+ index == y or value[y - 1, index] < value[y - 1, index - 1]
45
+ ):
46
+ index = index - 1
MeloTTS/melo/preprocess_text.py ADDED
@@ -0,0 +1,135 @@
1
+ import json
2
+ from collections import defaultdict
3
+ from random import shuffle
4
+ from typing import Optional
5
+
6
+ from tqdm import tqdm
7
+ import click
8
+ from text.cleaner import clean_text_bert
9
+ import os
10
+ import torch
11
+ from text.symbols import symbols, num_languages, num_tones
12
+
13
+ @click.command()
14
+ @click.option(
15
+ "--metadata",
16
+ default="data/example/metadata.list",
17
+ type=click.Path(exists=True, file_okay=True, dir_okay=False),
18
+ )
19
+ @click.option("--cleaned-path", default=None)
20
+ @click.option("--train-path", default=None)
21
+ @click.option("--val-path", default=None)
22
+ @click.option(
23
+ "--config_path",
24
+ default="configs/config.json",
25
+ type=click.Path(exists=True, file_okay=True, dir_okay=False),
26
+ )
27
+ @click.option("--val-per-spk", default=4)
28
+ @click.option("--max-val-total", default=8)
29
+ @click.option("--clean/--no-clean", default=True)
30
+ def main(
31
+ metadata: str,
32
+ cleaned_path: Optional[str],
33
+ train_path: str,
34
+ val_path: str,
35
+ config_path: str,
36
+ val_per_spk: int,
37
+ max_val_total: int,
38
+ clean: bool,
39
+ ):
40
+ if train_path is None:
41
+ train_path = os.path.join(os.path.dirname(metadata), 'train.list')
42
+ if val_path is None:
43
+ val_path = os.path.join(os.path.dirname(metadata), 'val.list')
44
+ out_config_path = os.path.join(os.path.dirname(metadata), 'config.json')
45
+
46
+ if cleaned_path is None:
47
+ cleaned_path = metadata + ".cleaned"
48
+
49
+ if clean:
50
+ out_file = open(cleaned_path, "w", encoding="utf-8")
51
+ new_symbols = []
52
+ for line in tqdm(open(metadata, encoding="utf-8").readlines()):
53
+ try:
54
+ utt, spk, language, text = line.strip().split("|")
55
+ norm_text, phones, tones, word2ph, bert = clean_text_bert(text, language, device='cuda:0')
56
+ for ph in phones:
57
+ if ph not in symbols and ph not in new_symbols:
58
+ new_symbols.append(ph)
59
+ print('update!, now symbols:')
60
+ print(new_symbols)
61
+ with open(f'{language}_symbol.txt', 'w') as f:
62
+ f.write(f'{new_symbols}')
63
+
64
+ assert len(phones) == len(tones)
65
+ assert len(phones) == sum(word2ph)
66
+ out_file.write(
67
+ "{}|{}|{}|{}|{}|{}|{}\n".format(
68
+ utt,
69
+ spk,
70
+ language,
71
+ norm_text,
72
+ " ".join(phones),
73
+ " ".join([str(i) for i in tones]),
74
+ " ".join([str(i) for i in word2ph]),
75
+ )
76
+ )
77
+ bert_path = utt.replace(".wav", ".bert.pt")
78
+ os.makedirs(os.path.dirname(bert_path), exist_ok=True)
79
+ torch.save(bert.cpu(), bert_path)
80
+ except Exception as error:
81
+ print("err!", line, error)
82
+
83
+ out_file.close()
84
+
85
+ metadata = cleaned_path
86
+
87
+ spk_utt_map = defaultdict(list)
88
+ spk_id_map = {}
89
+ current_sid = 0
90
+
91
+ with open(metadata, encoding="utf-8") as f:
92
+ for line in f.readlines():
93
+ utt, spk, language, text, phones, tones, word2ph = line.strip().split("|")
94
+ spk_utt_map[spk].append(line)
95
+
96
+ if spk not in spk_id_map.keys():
97
+ spk_id_map[spk] = current_sid
98
+ current_sid += 1
99
+
100
+ train_list = []
101
+ val_list = []
102
+
103
+ for spk, utts in spk_utt_map.items():
104
+ shuffle(utts)
105
+ val_list += utts[:val_per_spk]
106
+ train_list += utts[val_per_spk:]
107
+
108
+ if len(val_list) > max_val_total:
109
+ train_list += val_list[max_val_total:]
110
+ val_list = val_list[:max_val_total]
111
+
112
+ with open(train_path, "w", encoding="utf-8") as f:
113
+ for line in train_list:
114
+ f.write(line)
115
+
116
+ with open(val_path, "w", encoding="utf-8") as f:
117
+ for line in val_list:
118
+ f.write(line)
119
+
120
+ config = json.load(open(config_path, encoding="utf-8"))
121
+ config["data"]["spk2id"] = spk_id_map
122
+
123
+ config["data"]["training_files"] = train_path
124
+ config["data"]["validation_files"] = val_path
125
+ config["data"]["n_speakers"] = len(spk_id_map)
126
+ config["num_languages"] = num_languages
127
+ config["num_tones"] = num_tones
128
+ config["symbols"] = symbols
129
+
130
+ with open(out_config_path, "w", encoding="utf-8") as f:
131
+ json.dump(config, f, indent=2, ensure_ascii=False)
132
+
133
+
134
+ if __name__ == "__main__":
135
+ main()
MeloTTS/melo/split_utils.py ADDED
@@ -0,0 +1,174 @@
1
+ import re
2
+ import os
3
+ import glob
4
+ import numpy as np
5
+ import soundfile as sf
6
+ import torchaudio
7
+ import re
8
+
9
+ def split_sentence(text, min_len=10, language_str='EN'):
10
+ if language_str in ['EN', 'FR', 'ES', 'SP']:
11
+ sentences = split_sentences_latin(text, min_len=min_len)
12
+ else:
13
+ sentences = split_sentences_zh(text, min_len=min_len)
14
+ return sentences
15
+
16
+
17
+ def split_sentences_latin(text, min_len=10):
18
+ text = re.sub('[。!?;]', '.', text)
19
+ text = re.sub('[,]', ',', text)
20
+ text = re.sub('[“”]', '"', text)
21
+ text = re.sub('[‘’]', "'", text)
22
+ text = re.sub(r"[\<\>\(\)\[\]\"\«\»]+", "", text)
23
+ return [item.strip() for item in txtsplit(text, 256, 512) if item.strip()]
24
+
25
+
26
+ def split_sentences_zh(text, min_len=10):
27
+ text = re.sub('[。!?;]', '.', text)
28
+ text = re.sub('[,]', ',', text)
29
+ # 将文本中的换行符、空格和制表符替换为空格
30
+ text = re.sub('[\n\t ]+', ' ', text)
31
+ # 在标点符号后添加一个空格
32
+ text = re.sub('([,.!?;])', r'\1 $#!', text)
33
+ # 分隔句子并去除前后空格
34
+ # sentences = [s.strip() for s in re.split('(。|!|?|;)', text)]
35
+ sentences = [s.strip() for s in text.split('$#!')]
36
+ if len(sentences[-1]) == 0: del sentences[-1]
37
+
38
+ new_sentences = []
39
+ new_sent = []
40
+ count_len = 0
41
+ for ind, sent in enumerate(sentences):
42
+ new_sent.append(sent)
43
+ count_len += len(sent)
44
+ if count_len > min_len or ind == len(sentences) - 1:
45
+ count_len = 0
46
+ new_sentences.append(' '.join(new_sent))
47
+ new_sent = []
48
+ return merge_short_sentences_zh(new_sentences)
49
+
50
+
51
+ def merge_short_sentences_en(sens):
52
+ """Avoid short sentences by merging them with the following sentence.
53
+
54
+ Args:
55
+ List[str]: list of input sentences.
56
+
57
+ Returns:
58
+ List[str]: list of output sentences.
59
+ """
60
+ sens_out = []
61
+ for s in sens:
62
+ # If the previous sentence is too short, merge it with
63
+ # the current sentence.
64
+ if len(sens_out) > 0 and len(sens_out[-1].split(" ")) <= 2:
65
+ sens_out[-1] = sens_out[-1] + " " + s
66
+ else:
67
+ sens_out.append(s)
68
+ try:
69
+ if len(sens_out[-1].split(" ")) <= 2:
70
+ sens_out[-2] = sens_out[-2] + " " + sens_out[-1]
71
+ sens_out.pop(-1)
72
+ except:
73
+ pass
74
+ return sens_out
75
+
76
+
77
+ def merge_short_sentences_zh(sens):
78
+ # return sens
79
+ """Avoid short sentences by merging them with the following sentence.
80
+
81
+ Args:
82
+ List[str]: list of input sentences.
83
+
84
+ Returns:
85
+ List[str]: list of output sentences.
86
+ """
87
+ sens_out = []
88
+ for s in sens:
89
+ # If the previous sentence is too short, merge it with
90
+ # the current sentence.
91
+ if len(sens_out) > 0 and len(sens_out[-1]) <= 2:
92
+ sens_out[-1] = sens_out[-1] + " " + s
93
+ else:
94
+ sens_out.append(s)
95
+ try:
96
+ if len(sens_out[-1]) <= 2:
97
+ sens_out[-2] = sens_out[-2] + " " + sens_out[-1]
98
+ sens_out.pop(-1)
99
+ except:
100
+ pass
101
+ return sens_out
102
+
103
+
104
+
105
+ def txtsplit(text, desired_length=100, max_length=200):
106
+ """Split text into chunks of a desired length, trying to keep sentences intact."""
107
+ text = re.sub(r'\n\n+', '\n', text)
108
+ text = re.sub(r'\s+', ' ', text)
109
+ text = re.sub(r'[""]', '"', text)
110
+ text = re.sub(r'([,.?!])', r'\1 ', text)
111
+ text = re.sub(r'\s+', ' ', text)
112
+
113
+ rv = []
114
+ in_quote = False
115
+ current = ""
116
+ split_pos = []
117
+ pos = -1
118
+ end_pos = len(text) - 1
119
+ def seek(delta):
120
+ nonlocal pos, in_quote, current
121
+ is_neg = delta < 0
122
+ for _ in range(abs(delta)):
123
+ if is_neg:
124
+ pos -= 1
125
+ current = current[:-1]
126
+ else:
127
+ pos += 1
128
+ current += text[pos]
129
+ if text[pos] == '"':
130
+ in_quote = not in_quote
131
+ return text[pos]
132
+ def peek(delta):
133
+ p = pos + delta
134
+ return text[p] if p < end_pos and p >= 0 else ""
135
+ def commit():
136
+ nonlocal rv, current, split_pos
137
+ rv.append(current)
138
+ current = ""
139
+ split_pos = []
140
+ while pos < end_pos:
141
+ c = seek(1)
142
+ if len(current) >= max_length:
143
+ if len(split_pos) > 0 and len(current) > (desired_length / 2):
144
+ d = pos - split_pos[-1]
145
+ seek(-d)
146
+ else:
147
+ while c not in '!?.\n ' and pos > 0 and len(current) > desired_length:
148
+ c = seek(-1)
149
+ commit()
150
+ elif not in_quote and (c in '!?\n' or (c in '.,' and peek(1) in '\n ')):
151
+ while pos < len(text) - 1 and len(current) < max_length and peek(1) in '!?.':
152
+ c = seek(1)
153
+ split_pos.append(pos)
154
+ if len(current) >= desired_length:
155
+ commit()
156
+ elif in_quote and peek(1) == '"' and peek(2) in '\n ':
157
+ seek(2)
158
+ split_pos.append(pos)
159
+ rv.append(current)
160
+ rv = [s.strip() for s in rv]
161
+ rv = [s for s in rv if len(s) > 0 and not re.match(r'^[\s\.,;:!?]*$', s)]
162
+ return rv
163
+
164
+
165
+ if __name__ == '__main__':
166
+ zh_text = "好的,我来给你讲一个故事吧。从前有一个小姑娘,她叫做小红。小红非常喜欢在森林里玩耍,她经常会和她的小伙伴们一起去探险。有一天,小红和她的小伙伴们走到了森林深处,突然遇到了一只凶猛的野兽。小红的小伙伴们都吓得不敢动弹,但是小红并没有被吓倒,她勇敢地走向野兽,用她的智慧和勇气成功地制服了野兽,保护了她的小伙伴们。从那以后,小红变得更加勇敢和自信,成为了她小伙伴们心中的英雄。"
167
+ en_text = "I didn’t know what to do. I said please kill her because it would be better than being kidnapped,” Ben, whose surname CNN is not using for security concerns, said on Wednesday. “It’s a nightmare. I said ‘please kill her, don’t take her there.’"
168
+ sp_text = "¡Claro! ¿En qué tema te gustaría que te hable en español? Puedo proporcionarte información o conversar contigo sobre una amplia variedad de temas, desde cultura y comida hasta viajes y tecnología. ¿Tienes alguna preferencia en particular?"
169
+ fr_text = "Bien sûr ! En quelle matière voudriez-vous que je vous parle en français ? Je peux vous fournir des informations ou discuter avec vous sur une grande variété de sujets, que ce soit la culture, la nourriture, les voyages ou la technologie. Avez-vous une préférence particulière ?"
170
+
171
+ print(split_sentence(zh_text, language_str='ZH'))
172
+ print(split_sentence(en_text, language_str='EN'))
173
+ print(split_sentence(sp_text, language_str='SP'))
174
+ print(split_sentence(fr_text, language_str='FR'))
MeloTTS/melo/text/__init__.py ADDED
@@ -0,0 +1,35 @@
1
+ from .symbols import *
2
+
3
+
4
+ _symbol_to_id = {s: i for i, s in enumerate(symbols)}
5
+
6
+
7
+ def cleaned_text_to_sequence(cleaned_text, tones, language, symbol_to_id=None):
8
+ """Converts a string of text to a sequence of IDs corresponding to the symbols in the text.
9
+ Args:
10
+ text: string to convert to a sequence
11
+ Returns:
12
+ List of integers corresponding to the symbols in the text
13
+ """
14
+ symbol_to_id_map = symbol_to_id if symbol_to_id else _symbol_to_id
15
+ phones = [symbol_to_id_map[symbol] for symbol in cleaned_text]
16
+ tone_start = language_tone_start_map[language]
17
+ tones = [i + tone_start for i in tones]
18
+ lang_id = language_id_map[language]
19
+ lang_ids = [lang_id for i in phones]
20
+ return phones, tones, lang_ids
21
+
22
+
23
+ def get_bert(norm_text, word2ph, language, device):
24
+ from .chinese_bert import get_bert_feature as zh_bert
25
+ from .english_bert import get_bert_feature as en_bert
26
+ from .japanese_bert import get_bert_feature as jp_bert
27
+ from .chinese_mix import get_bert_feature as zh_mix_en_bert
28
+ from .spanish_bert import get_bert_feature as sp_bert
29
+ from .french_bert import get_bert_feature as fr_bert
30
+ from .korean import get_bert_feature as kr_bert
31
+
32
+ lang_bert_func_map = {"ZH": zh_bert, "EN": en_bert, "JP": jp_bert, 'ZH_MIX_EN': zh_mix_en_bert,
33
+ 'FR': fr_bert, 'SP': sp_bert, 'ES': sp_bert, "KR": kr_bert}
34
+ bert = lang_bert_func_map[language](norm_text, word2ph, device)
35
+ return bert
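
As a quick illustration of the contract above: cleaned_text_to_sequence only performs table lookups, mapping phoneme symbols to IDs, shifting the per-phoneme tones by the language's tone offset, and emitting one language ID per phoneme. A hedged sketch; the ARPAbet symbols below are assumed to be present in melo.text.symbols:

from melo.text import cleaned_text_to_sequence

phones = ["_", "HH", "AH0", "L", "OW1", "_"]   # "hello", padded with "_" (assumed symbols)
tones = [0, 0, 0, 0, 1, 0]                     # English "tones" encode lexical stress
phone_ids, tone_ids, lang_ids = cleaned_text_to_sequence(phones, tones, "EN")
print(len(phone_ids), len(tone_ids), len(lang_ids))   # 6 6 6
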
MeloTTS/melo/text/chinese.py ADDED
@@ -0,0 +1,199 @@
1
+ import os
2
+ import re
3
+
4
+ import cn2an
5
+ from pypinyin import lazy_pinyin, Style
6
+
7
+ from .symbols import punctuation
8
+ from .tone_sandhi import ToneSandhi
9
+
10
+ current_file_path = os.path.dirname(__file__)
11
+ pinyin_to_symbol_map = {
12
+ line.split("\t")[0]: line.strip().split("\t")[1]
13
+ for line in open(os.path.join(current_file_path, "opencpop-strict.txt")).readlines()
14
+ }
15
+
16
+ import jieba.posseg as psg
17
+
18
+
19
+ rep_map = {
20
+ ":": ",",
21
+ ";": ",",
22
+ ",": ",",
23
+ "。": ".",
24
+ "!": "!",
25
+ "?": "?",
26
+ "\n": ".",
27
+ "·": ",",
28
+ "、": ",",
29
+ "...": "…",
30
+ "$": ".",
31
+ "“": "'",
32
+ "”": "'",
33
+ "‘": "'",
34
+ "’": "'",
35
+ "(": "'",
36
+ ")": "'",
37
+ "(": "'",
38
+ ")": "'",
39
+ "《": "'",
40
+ "》": "'",
41
+ "【": "'",
42
+ "】": "'",
43
+ "[": "'",
44
+ "]": "'",
45
+ "—": "-",
46
+ "~": "-",
47
+ "~": "-",
48
+ "「": "'",
49
+ "」": "'",
50
+ }
51
+
52
+ tone_modifier = ToneSandhi()
53
+
54
+
55
+ def replace_punctuation(text):
56
+ text = text.replace("嗯", "恩").replace("呣", "母")
57
+ pattern = re.compile("|".join(re.escape(p) for p in rep_map.keys()))
58
+
59
+ replaced_text = pattern.sub(lambda x: rep_map[x.group()], text)
60
+
61
+ replaced_text = re.sub(
62
+ r"[^\u4e00-\u9fa5" + "".join(punctuation) + r"]+", "", replaced_text
63
+ )
64
+
65
+ return replaced_text
66
+
67
+
68
+ def g2p(text):
69
+ pattern = r"(?<=[{0}])\s*".format("".join(punctuation))
70
+ sentences = [i for i in re.split(pattern, text) if i.strip() != ""]
71
+ phones, tones, word2ph = _g2p(sentences)
72
+ assert sum(word2ph) == len(phones)
73
+ assert len(word2ph) == len(text)  # Sometimes this will crash; you can add a try-catch.
74
+ phones = ["_"] + phones + ["_"]
75
+ tones = [0] + tones + [0]
76
+ word2ph = [1] + word2ph + [1]
77
+ return phones, tones, word2ph
78
+
79
+
80
+ def _get_initials_finals(word):
81
+ initials = []
82
+ finals = []
83
+ orig_initials = lazy_pinyin(word, neutral_tone_with_five=True, style=Style.INITIALS)
84
+ orig_finals = lazy_pinyin(
85
+ word, neutral_tone_with_five=True, style=Style.FINALS_TONE3
86
+ )
87
+ for c, v in zip(orig_initials, orig_finals):
88
+ initials.append(c)
89
+ finals.append(v)
90
+ return initials, finals
91
+
92
+
93
+ def _g2p(segments):
94
+ phones_list = []
95
+ tones_list = []
96
+ word2ph = []
97
+ for seg in segments:
98
+ # Replace all English words in the sentence
99
+ seg = re.sub("[a-zA-Z]+", "", seg)
100
+ seg_cut = psg.lcut(seg)
101
+ initials = []
102
+ finals = []
103
+ seg_cut = tone_modifier.pre_merge_for_modify(seg_cut)
104
+ for word, pos in seg_cut:
105
+ if pos == "eng":
106
+ import pdb; pdb.set_trace()
107
+ continue
108
+ sub_initials, sub_finals = _get_initials_finals(word)
109
+ sub_finals = tone_modifier.modified_tone(word, pos, sub_finals)
110
+ initials.append(sub_initials)
111
+ finals.append(sub_finals)
112
+
113
+ # assert len(sub_initials) == len(sub_finals) == len(word)
114
+ initials = sum(initials, [])
115
+ finals = sum(finals, [])
116
+ #
117
+ for c, v in zip(initials, finals):
118
+ raw_pinyin = c + v
119
+ # NOTE: post process for pypinyin outputs
120
+ # we discriminate i, ii and iii
121
+ if c == v:
122
+ assert c in punctuation
123
+ phone = [c]
124
+ tone = "0"
125
+ word2ph.append(1)
126
+ else:
127
+ v_without_tone = v[:-1]
128
+ tone = v[-1]
129
+
130
+ pinyin = c + v_without_tone
131
+ assert tone in "12345"
132
+
133
+ if c:
134
+ # 多音节
135
+ v_rep_map = {
136
+ "uei": "ui",
137
+ "iou": "iu",
138
+ "uen": "un",
139
+ }
140
+ if v_without_tone in v_rep_map.keys():
141
+ pinyin = c + v_rep_map[v_without_tone]
142
+ else:
143
+ # 单音节
144
+ pinyin_rep_map = {
145
+ "ing": "ying",
146
+ "i": "yi",
147
+ "in": "yin",
148
+ "u": "wu",
149
+ }
150
+ if pinyin in pinyin_rep_map.keys():
151
+ pinyin = pinyin_rep_map[pinyin]
152
+ else:
153
+ single_rep_map = {
154
+ "v": "yu",
155
+ "e": "e",
156
+ "i": "y",
157
+ "u": "w",
158
+ }
159
+ if pinyin[0] in single_rep_map.keys():
160
+ pinyin = single_rep_map[pinyin[0]] + pinyin[1:]
161
+
162
+ assert pinyin in pinyin_to_symbol_map.keys(), (pinyin, seg, raw_pinyin)
163
+ phone = pinyin_to_symbol_map[pinyin].split(" ")
164
+ word2ph.append(len(phone))
165
+
166
+ phones_list += phone
167
+ tones_list += [int(tone)] * len(phone)
168
+ return phones_list, tones_list, word2ph
169
+
170
+
171
+ def text_normalize(text):
172
+ numbers = re.findall(r"\d+(?:\.?\d+)?", text)
173
+ for number in numbers:
174
+ text = text.replace(number, cn2an.an2cn(number), 1)
175
+ text = replace_punctuation(text)
176
+ return text
177
+
178
+
179
+ def get_bert_feature(text, word2ph, device=None):
180
+ from text import chinese_bert
181
+
182
+ return chinese_bert.get_bert_feature(text, word2ph, device=device)
183
+
184
+
185
+ if __name__ == "__main__":
186
+ from text.chinese_bert import get_bert_feature
187
+
188
+ text = "啊!chemistry 但是《原神》是由,米哈\游自主, [研发]的一款全.新开放世界.冒险游戏"
189
+ text = text_normalize(text)
190
+ print(text)
191
+ phones, tones, word2ph = g2p(text)
192
+ bert = get_bert_feature(text, word2ph)
193
+
194
+ print(phones, tones, word2ph, bert.shape)
195
+
196
+
197
+ # # 示例用法
198
+ # text = "这是一个示例文本:,你好!这是一个测试...."
199
+ # print(g2p_paddle(text)) # 输出: 这是一个示例文本你好这是一个测试
MeloTTS/melo/text/chinese_bert.py ADDED
@@ -0,0 +1,107 @@
1
+ import torch
2
+ import sys
3
+ from transformers import AutoTokenizer, AutoModelForMaskedLM
4
+
5
+
6
+ # model_id = 'hfl/chinese-roberta-wwm-ext-large'
7
+ local_path = "./bert/chinese-roberta-wwm-ext-large"
8
+
9
+
10
+ tokenizers = {}
11
+ models = {}
12
+
13
+ def get_bert_feature(text, word2ph, device=None, model_id='hfl/chinese-roberta-wwm-ext-large'):
14
+ if model_id not in models:
15
+ models[model_id] = AutoModelForMaskedLM.from_pretrained(
16
+ model_id
17
+ ).to(device)
18
+ tokenizers[model_id] = AutoTokenizer.from_pretrained(model_id)
19
+ model = models[model_id]
20
+ tokenizer = tokenizers[model_id]
21
+
22
+ if (
23
+ sys.platform == "darwin"
24
+ and torch.backends.mps.is_available()
25
+ and device == "cpu"
26
+ ):
27
+ device = "mps"
28
+ if not device:
29
+ device = "cuda"
30
+
31
+ with torch.no_grad():
32
+ inputs = tokenizer(text, return_tensors="pt")
33
+ for i in inputs:
34
+ inputs[i] = inputs[i].to(device)
35
+ res = model(**inputs, output_hidden_states=True)
36
+ res = torch.cat(res["hidden_states"][-3:-2], -1)[0].cpu()
37
+ # import pdb; pdb.set_trace()
38
+ # assert len(word2ph) == len(text) + 2
39
+ word2phone = word2ph
40
+ phone_level_feature = []
41
+ for i in range(len(word2phone)):
42
+ repeat_feature = res[i].repeat(word2phone[i], 1)
43
+ phone_level_feature.append(repeat_feature)
44
+
45
+ phone_level_feature = torch.cat(phone_level_feature, dim=0)
46
+ return phone_level_feature.T
47
+
48
+
49
+ if __name__ == "__main__":
50
+ import torch
51
+
52
+ word_level_feature = torch.rand(38, 1024) # 12个词,每个词1024维特征
53
+ word2phone = [
54
+ 1,
55
+ 2,
56
+ 1,
57
+ 2,
58
+ 2,
59
+ 1,
60
+ 2,
61
+ 2,
62
+ 1,
63
+ 2,
64
+ 2,
65
+ 1,
66
+ 2,
67
+ 2,
68
+ 2,
69
+ 2,
70
+ 2,
71
+ 1,
72
+ 1,
73
+ 2,
74
+ 2,
75
+ 1,
76
+ 2,
77
+ 2,
78
+ 2,
79
+ 2,
80
+ 1,
81
+ 2,
82
+ 2,
83
+ 2,
84
+ 2,
85
+ 2,
86
+ 1,
87
+ 2,
88
+ 2,
89
+ 2,
90
+ 2,
91
+ 1,
92
+ ]
93
+
94
+ # 计算总帧数
95
+ total_frames = sum(word2phone)
96
+ print(word_level_feature.shape)
97
+ print(word2phone)
98
+ phone_level_feature = []
99
+ for i in range(len(word2phone)):
100
+ print(word_level_feature[i].shape)
101
+
102
+ # 对每个词重复word2phone[i]次
103
+ repeat_feature = word_level_feature[i].repeat(word2phone[i], 1)
104
+ phone_level_feature.append(repeat_feature)
105
+
106
+ phone_level_feature = torch.cat(phone_level_feature, dim=0)
107
+ print(phone_level_feature.shape) # torch.Size([36, 1024])
MeloTTS/melo/text/chinese_mix.py ADDED
@@ -0,0 +1,253 @@
1
+ import os
2
+ import re
3
+
4
+ import cn2an
5
+ from pypinyin import lazy_pinyin, Style
6
+
7
+ # from text.symbols import punctuation
8
+ from .symbols import language_tone_start_map
9
+ from .tone_sandhi import ToneSandhi
10
+ from .english import g2p as g2p_en
11
+ from transformers import AutoTokenizer
12
+
13
+ punctuation = ["!", "?", "…", ",", ".", "'", "-"]
14
+ current_file_path = os.path.dirname(__file__)
15
+ pinyin_to_symbol_map = {
16
+ line.split("\t")[0]: line.strip().split("\t")[1]
17
+ for line in open(os.path.join(current_file_path, "opencpop-strict.txt")).readlines()
18
+ }
19
+
20
+ import jieba.posseg as psg
21
+
22
+
23
+ rep_map = {
24
+ ":": ",",
25
+ ";": ",",
26
+ ",": ",",
27
+ "。": ".",
28
+ "!": "!",
29
+ "?": "?",
30
+ "\n": ".",
31
+ "·": ",",
32
+ "、": ",",
33
+ "...": "…",
34
+ "$": ".",
35
+ "“": "'",
36
+ "”": "'",
37
+ "‘": "'",
38
+ "’": "'",
39
+ "(": "'",
40
+ ")": "'",
41
+ "(": "'",
42
+ ")": "'",
43
+ "《": "'",
44
+ "》": "'",
45
+ "【": "'",
46
+ "】": "'",
47
+ "[": "'",
48
+ "]": "'",
49
+ "—": "-",
50
+ "~": "-",
51
+ "~": "-",
52
+ "「": "'",
53
+ "」": "'",
54
+ }
55
+
56
+ tone_modifier = ToneSandhi()
57
+
58
+
59
+ def replace_punctuation(text):
60
+ text = text.replace("嗯", "恩").replace("呣", "母")
61
+ pattern = re.compile("|".join(re.escape(p) for p in rep_map.keys()))
62
+ replaced_text = pattern.sub(lambda x: rep_map[x.group()], text)
63
+ replaced_text = re.sub(r"[^\u4e00-\u9fa5_a-zA-Z\s" + "".join(punctuation) + r"]+", "", replaced_text)
64
+ replaced_text = re.sub(r"[\s]+", " ", replaced_text)
65
+
66
+ return replaced_text
67
+
68
+
69
+ def g2p(text, impl='v2'):
70
+ pattern = r"(?<=[{0}])\s*".format("".join(punctuation))
71
+ sentences = [i for i in re.split(pattern, text) if i.strip() != ""]
72
+ if impl == 'v1':
73
+ _func = _g2p
74
+ elif impl == 'v2':
75
+ _func = _g2p_v2
76
+ else:
77
+ raise NotImplementedError()
78
+ phones, tones, word2ph = _func(sentences)
79
+ assert sum(word2ph) == len(phones)
80
+ # assert len(word2ph) == len(text) # Sometimes it will crash,you can add a try-catch.
81
+ phones = ["_"] + phones + ["_"]
82
+ tones = [0] + tones + [0]
83
+ word2ph = [1] + word2ph + [1]
84
+ return phones, tones, word2ph
85
+
86
+
87
+ def _get_initials_finals(word):
88
+ initials = []
89
+ finals = []
90
+ orig_initials = lazy_pinyin(word, neutral_tone_with_five=True, style=Style.INITIALS)
91
+ orig_finals = lazy_pinyin(
92
+ word, neutral_tone_with_five=True, style=Style.FINALS_TONE3
93
+ )
94
+ for c, v in zip(orig_initials, orig_finals):
95
+ initials.append(c)
96
+ finals.append(v)
97
+ return initials, finals
98
+
99
+ model_id = 'bert-base-multilingual-uncased'
100
+ tokenizer = AutoTokenizer.from_pretrained(model_id)
101
+ def _g2p(segments):
102
+ phones_list = []
103
+ tones_list = []
104
+ word2ph = []
105
+ for seg in segments:
106
+ # Replace all English words in the sentence
107
+ # seg = re.sub("[a-zA-Z]+", "", seg)
108
+ seg_cut = psg.lcut(seg)
109
+ initials = []
110
+ finals = []
111
+ seg_cut = tone_modifier.pre_merge_for_modify(seg_cut)
112
+ for word, pos in seg_cut:
113
+ if pos == "eng":
114
+ initials.append(['EN_WORD'])
115
+ finals.append([word])
116
+ else:
117
+ sub_initials, sub_finals = _get_initials_finals(word)
118
+ sub_finals = tone_modifier.modified_tone(word, pos, sub_finals)
119
+ initials.append(sub_initials)
120
+ finals.append(sub_finals)
121
+
122
+ # assert len(sub_initials) == len(sub_finals) == len(word)
123
+ initials = sum(initials, [])
124
+ finals = sum(finals, [])
125
+ #
126
+ for c, v in zip(initials, finals):
127
+ if c == 'EN_WORD':
128
+ tokenized_en = tokenizer.tokenize(v)
129
+ phones_en, tones_en, word2ph_en = g2p_en(text=None, pad_start_end=False, tokenized=tokenized_en)
130
+ # apply offset to tones_en
131
+ tones_en = [t + language_tone_start_map['EN'] for t in tones_en]
132
+ phones_list += phones_en
133
+ tones_list += tones_en
134
+ word2ph += word2ph_en
135
+ else:
136
+ raw_pinyin = c + v
137
+ # NOTE: post process for pypinyin outputs
138
+ # we discriminate i, ii and iii
139
+ if c == v:
140
+ assert c in punctuation
141
+ phone = [c]
142
+ tone = "0"
143
+ word2ph.append(1)
144
+ else:
145
+ v_without_tone = v[:-1]
146
+ tone = v[-1]
147
+
148
+ pinyin = c + v_without_tone
149
+ assert tone in "12345"
150
+
151
+ if c:
152
+ # 多音节
153
+ v_rep_map = {
154
+ "uei": "ui",
155
+ "iou": "iu",
156
+ "uen": "un",
157
+ }
158
+ if v_without_tone in v_rep_map.keys():
159
+ pinyin = c + v_rep_map[v_without_tone]
160
+ else:
161
+ # 单音节
162
+ pinyin_rep_map = {
163
+ "ing": "ying",
164
+ "i": "yi",
165
+ "in": "yin",
166
+ "u": "wu",
167
+ }
168
+ if pinyin in pinyin_rep_map.keys():
169
+ pinyin = pinyin_rep_map[pinyin]
170
+ else:
171
+ single_rep_map = {
172
+ "v": "yu",
173
+ "e": "e",
174
+ "i": "y",
175
+ "u": "w",
176
+ }
177
+ if pinyin[0] in single_rep_map.keys():
178
+ pinyin = single_rep_map[pinyin[0]] + pinyin[1:]
179
+
180
+ assert pinyin in pinyin_to_symbol_map.keys(), (pinyin, seg, raw_pinyin)
181
+ phone = pinyin_to_symbol_map[pinyin].split(" ")
182
+ word2ph.append(len(phone))
183
+
184
+ phones_list += phone
185
+ tones_list += [int(tone)] * len(phone)
186
+ return phones_list, tones_list, word2ph
187
+
188
+
189
+ def text_normalize(text):
190
+ numbers = re.findall(r"\d+(?:\.?\d+)?", text)
191
+ for number in numbers:
192
+ text = text.replace(number, cn2an.an2cn(number), 1)
193
+ text = replace_punctuation(text)
194
+ return text
195
+
196
+
197
+ def get_bert_feature(text, word2ph, device):
198
+ from . import chinese_bert
199
+ return chinese_bert.get_bert_feature(text, word2ph, model_id='bert-base-multilingual-uncased', device=device)
200
+
201
+ from .chinese import _g2p as _chinese_g2p
202
+ def _g2p_v2(segments):
203
+ spliter = '#$&^!@'
204
+
205
+ phones_list = []
206
+ tones_list = []
207
+ word2ph = []
208
+
209
+ for text in segments:
210
+ assert spliter not in text
211
+ # replace all english words
212
+ text = re.sub('([a-zA-Z\s]+)', lambda x: f'{spliter}{x.group(1)}{spliter}', text)
213
+ texts = text.split(spliter)
214
+ texts = [t for t in texts if len(t) > 0]
215
+
216
+
217
+ for text in texts:
218
+ if re.match('[a-zA-Z\s]+', text):
219
+ # english
220
+ tokenized_en = tokenizer.tokenize(text)
221
+ phones_en, tones_en, word2ph_en = g2p_en(text=None, pad_start_end=False, tokenized=tokenized_en)
222
+ # apply offset to tones_en
223
+ tones_en = [t + language_tone_start_map['EN'] for t in tones_en]
224
+ phones_list += phones_en
225
+ tones_list += tones_en
226
+ word2ph += word2ph_en
227
+ else:
228
+ phones_zh, tones_zh, word2ph_zh = _chinese_g2p([text])
229
+ phones_list += phones_zh
230
+ tones_list += tones_zh
231
+ word2ph += word2ph_zh
232
+ return phones_list, tones_list, word2ph
233
+
234
+
235
+
236
+ if __name__ == "__main__":
237
+ # from text.chinese_bert import get_bert_feature
238
+
239
+ text = "NFT啊!chemistry 但是《原神》是由,米哈\游自主, [研发]的一款全.新开放世界.冒险游戏"
240
+ text = '我最近在学习machine learning,希望能够在未来的artificial intelligence领域有所建树。'
241
+ text = '今天下午,我们准备去shopping mall购物,然后晚上去看一场movie。'
242
+ text = '我们现在 also 能够 help 很多公司 use some machine learning 的 algorithms 啊!'
243
+ text = text_normalize(text)
244
+ print(text)
245
+ phones, tones, word2ph = g2p(text, impl='v2')
246
+ bert = get_bert_feature(text, word2ph, device='cuda:0')
247
+ print(phones)
248
+ import pdb; pdb.set_trace()
249
+
250
+
251
+ # # 示例用法
252
+ # text = "这是一个示例文本:,你好!这是一个测试...."
253
+ # print(g2p_paddle(text)) # 输出: 这是一个示例文本你好这是一个测试
MeloTTS/melo/text/cleaner.py ADDED
@@ -0,0 +1,36 @@
1
+ from . import chinese, japanese, english, chinese_mix, korean, french, spanish
2
+ from . import cleaned_text_to_sequence
3
+ import copy
4
+
5
+ language_module_map = {"ZH": chinese, "JP": japanese, "EN": english, 'ZH_MIX_EN': chinese_mix, 'KR': korean,
6
+ 'FR': french, 'SP': spanish, 'ES': spanish}
7
+
8
+
9
+ def clean_text(text, language):
10
+ language_module = language_module_map[language]
11
+ norm_text = language_module.text_normalize(text)
12
+ phones, tones, word2ph = language_module.g2p(norm_text)
13
+ return norm_text, phones, tones, word2ph
14
+
15
+
16
+ def clean_text_bert(text, language, device=None):
17
+ language_module = language_module_map[language]
18
+ norm_text = language_module.text_normalize(text)
19
+ phones, tones, word2ph = language_module.g2p(norm_text)
20
+
21
+ word2ph_bak = copy.deepcopy(word2ph)
22
+ for i in range(len(word2ph)):
23
+ word2ph[i] = word2ph[i] * 2
24
+ word2ph[0] += 1
25
+ bert = language_module.get_bert_feature(norm_text, word2ph, device=device)
26
+
27
+ return norm_text, phones, tones, word2ph_bak, bert
28
+
29
+
30
+ def text_to_sequence(text, language):
31
+ norm_text, phones, tones, word2ph = clean_text(text, language)
32
+ return cleaned_text_to_sequence(phones, tones, language)
33
+
34
+
35
+ if __name__ == "__main__":
36
+ pass
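
A minimal sketch of the cleaning front end defined above, assuming the melo package and the language-specific dependencies are installed ("EN" routes through the english module):

from melo.text.cleaner import clean_text

norm_text, phones, tones, word2ph = clean_text("Hello world, this is a test.", "EN")
print(norm_text)                       # normalized text
print(phones[:8])                      # phoneme symbols, start/end padded with "_"
print(sum(word2ph) == len(phones))     # True: word2ph gives phonemes per input token
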
MeloTTS/melo/text/cleaner_multiling.py ADDED
@@ -0,0 +1,110 @@
1
+ """Set of default text cleaners"""
2
+ # TODO: pick the cleaner for languages dynamically
3
+
4
+ import re
5
+
6
+ # Regular expression matching whitespace:
7
+ _whitespace_re = re.compile(r"\s+")
8
+
9
+ rep_map = {
10
+ ":": ",",
11
+ ";": ",",
12
+ ",": ",",
13
+ "。": ".",
14
+ "!": "!",
15
+ "?": "?",
16
+ "\n": ".",
17
+ "·": ",",
18
+ "、": ",",
19
+ "...": ".",
20
+ "…": ".",
21
+ "$": ".",
22
+ "“": "'",
23
+ "”": "'",
24
+ "‘": "'",
25
+ "’": "'",
26
+ "(": "'",
27
+ ")": "'",
28
+ "(": "'",
29
+ ")": "'",
30
+ "《": "'",
31
+ "》": "'",
32
+ "【": "'",
33
+ "】": "'",
34
+ "[": "'",
35
+ "]": "'",
36
+ "—": "",
37
+ "~": "-",
38
+ "~": "-",
39
+ "「": "'",
40
+ "」": "'",
41
+ }
42
+
43
+ def replace_punctuation(text):
44
+ pattern = re.compile("|".join(re.escape(p) for p in rep_map.keys()))
45
+ replaced_text = pattern.sub(lambda x: rep_map[x.group()], text)
46
+ return replaced_text
47
+
48
+ def lowercase(text):
49
+ return text.lower()
50
+
51
+
52
+ def collapse_whitespace(text):
53
+ return re.sub(_whitespace_re, " ", text).strip()
54
+
55
+ def remove_punctuation_at_begin(text):
56
+ return re.sub(r'^[,.!?]+', '', text)
57
+
58
+ def remove_aux_symbols(text):
59
+ text = re.sub(r"[\<\>\(\)\[\]\"\«\»\']+", "", text)
60
+ return text
61
+
62
+
63
+ def replace_symbols(text, lang="en"):
64
+ """Replace symbols based on the language tag.
65
+
66
+ Args:
67
+ text:
68
+ Input text.
69
+ lang:
70
+ Language identifier. ex: "en", "fr", "pt", "ca".
71
+
72
+ Returns:
73
+ The modified text
74
+ example:
75
+ input args:
76
+ text: "si l'avi cau, diguem-ho"
77
+ lang: "ca"
78
+ Output:
79
+ text: "si lavi cau, diguemho"
80
+ """
81
+ text = text.replace(";", ",")
82
+ text = text.replace("-", " ") if lang != "ca" else text.replace("-", "")
83
+ text = text.replace(":", ",")
84
+ if lang == "en":
85
+ text = text.replace("&", " and ")
86
+ elif lang == "fr":
87
+ text = text.replace("&", " et ")
88
+ elif lang == "pt":
89
+ text = text.replace("&", " e ")
90
+ elif lang == "ca":
91
+ text = text.replace("&", " i ")
92
+ text = text.replace("'", "")
93
+ elif lang == "es":
94
+ text = text.replace("&", " y ")
95
+ text = text.replace("'", "")
96
+ return text
97
+
98
+ def unicleaners(text, cased=False, lang='en'):
99
+ """Basic multilingual cleaning pipeline. There is no need to expand abbreviations and
100
+ numbers; the phonemizer already does that."""
101
+ if not cased:
102
+ text = lowercase(text)
103
+ text = replace_punctuation(text)
104
+ text = replace_symbols(text, lang=lang)
105
+ text = remove_aux_symbols(text)
106
+ text = remove_punctuation_at_begin(text)
107
+ text = collapse_whitespace(text)
108
+ text = re.sub(r'([^\.,!\?\-…])$', r'\1.', text)
109
+ return text
110
+
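A short sketch (not part of the commit) of how the unicleaners pipeline above behaves, assuming the melo package is importable:

    from melo.text.cleaner_multiling import unicleaners

    # lower-cased, ':' mapped to ',', the guillemets stripped, trailing '.' appended
    print(unicleaners("Bonjour:   le monde « test »", lang="fr"))
    # '&' expanded to ' and '; case preserved because cased=True
    print(unicleaners("Tom & Jerry", cased=True, lang="en"))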
MeloTTS/melo/text/cmudict.rep ADDED
The diff for this file is too large to render. See raw diff
 
MeloTTS/melo/text/cmudict_cache.pickle ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b9b21b20325471934ba92f2e4a5976989e7d920caa32e7a286eacb027d197949
3
+ size 6212655
MeloTTS/melo/text/english.py ADDED
@@ -0,0 +1,284 @@
1
+ import pickle
2
+ import os
3
+ import re
4
+ from g2p_en import G2p
5
+
6
+ from . import symbols
7
+
8
+ from .english_utils.abbreviations import expand_abbreviations
9
+ from .english_utils.time_norm import expand_time_english
10
+ from .english_utils.number_norm import normalize_numbers
11
+ from .japanese import distribute_phone
12
+
13
+ from transformers import AutoTokenizer
14
+
15
+ current_file_path = os.path.dirname(__file__)
16
+ CMU_DICT_PATH = os.path.join(current_file_path, "cmudict.rep")
17
+ CACHE_PATH = os.path.join(current_file_path, "cmudict_cache.pickle")
18
+ _g2p = G2p()
19
+
20
+ arpa = {
21
+ "AH0",
22
+ "S",
23
+ "AH1",
24
+ "EY2",
25
+ "AE2",
26
+ "EH0",
27
+ "OW2",
28
+ "UH0",
29
+ "NG",
30
+ "B",
31
+ "G",
32
+ "AY0",
33
+ "M",
34
+ "AA0",
35
+ "F",
36
+ "AO0",
37
+ "ER2",
38
+ "UH1",
39
+ "IY1",
40
+ "AH2",
41
+ "DH",
42
+ "IY0",
43
+ "EY1",
44
+ "IH0",
45
+ "K",
46
+ "N",
47
+ "W",
48
+ "IY2",
49
+ "T",
50
+ "AA1",
51
+ "ER1",
52
+ "EH2",
53
+ "OY0",
54
+ "UH2",
55
+ "UW1",
56
+ "Z",
57
+ "AW2",
58
+ "AW1",
59
+ "V",
60
+ "UW2",
61
+ "AA2",
62
+ "ER",
63
+ "AW0",
64
+ "UW0",
65
+ "R",
66
+ "OW1",
67
+ "EH1",
68
+ "ZH",
69
+ "AE0",
70
+ "IH2",
71
+ "IH",
72
+ "Y",
73
+ "JH",
74
+ "P",
75
+ "AY1",
76
+ "EY0",
77
+ "OY2",
78
+ "TH",
79
+ "HH",
80
+ "D",
81
+ "ER0",
82
+ "CH",
83
+ "AO1",
84
+ "AE1",
85
+ "AO2",
86
+ "OY1",
87
+ "AY2",
88
+ "IH1",
89
+ "OW0",
90
+ "L",
91
+ "SH",
92
+ }
93
+
94
+
95
+ def post_replace_ph(ph):
96
+ rep_map = {
97
+ ":": ",",
98
+ ";": ",",
99
+ ",": ",",
100
+ "。": ".",
101
+ "!": "!",
102
+ "?": "?",
103
+ "\n": ".",
104
+ "·": ",",
105
+ "、": ",",
106
+ "...": "…",
107
+ "v": "V",
108
+ }
109
+ if ph in rep_map.keys():
110
+ ph = rep_map[ph]
111
+ if ph in symbols:
112
+ return ph
113
+ if ph not in symbols:
114
+ ph = "UNK"
115
+ return ph
116
+
117
+
118
+ def read_dict():
119
+ g2p_dict = {}
120
+ start_line = 49
121
+ with open(CMU_DICT_PATH) as f:
122
+ line = f.readline()
123
+ line_index = 1
124
+ while line:
125
+ if line_index >= start_line:
126
+ line = line.strip()
127
+ word_split = line.split(" ")
128
+ word = word_split[0]
129
+
130
+ syllable_split = word_split[1].split(" - ")
131
+ g2p_dict[word] = []
132
+ for syllable in syllable_split:
133
+ phone_split = syllable.split(" ")
134
+ g2p_dict[word].append(phone_split)
135
+
136
+ line_index = line_index + 1
137
+ line = f.readline()
138
+
139
+ return g2p_dict
140
+
141
+
142
+ def cache_dict(g2p_dict, file_path):
143
+ with open(file_path, "wb") as pickle_file:
144
+ pickle.dump(g2p_dict, pickle_file)
145
+
146
+
147
+ def get_dict():
148
+ if os.path.exists(CACHE_PATH):
149
+ with open(CACHE_PATH, "rb") as pickle_file:
150
+ g2p_dict = pickle.load(pickle_file)
151
+ else:
152
+ g2p_dict = read_dict()
153
+ cache_dict(g2p_dict, CACHE_PATH)
154
+
155
+ return g2p_dict
156
+
157
+
158
+ eng_dict = get_dict()
159
+
160
+
161
+ def refine_ph(phn):
162
+ tone = 0
163
+ if re.search(r"\d$", phn):
164
+ tone = int(phn[-1]) + 1
165
+ phn = phn[:-1]
166
+ return phn.lower(), tone
167
+
168
+
169
+ def refine_syllables(syllables):
170
+ tones = []
171
+ phonemes = []
172
+ for phn_list in syllables:
173
+ for i in range(len(phn_list)):
174
+ phn = phn_list[i]
175
+ phn, tone = refine_ph(phn)
176
+ phonemes.append(phn)
177
+ tones.append(tone)
178
+ return phonemes, tones
179
+
180
+
181
+ def text_normalize(text):
182
+ text = text.lower()
183
+ text = expand_time_english(text)
184
+ text = normalize_numbers(text)
185
+ text = expand_abbreviations(text)
186
+ return text
187
+
188
+ model_id = 'bert-base-uncased'
189
+ tokenizer = AutoTokenizer.from_pretrained(model_id)
190
+ def g2p_old(text):
191
+ tokenized = tokenizer.tokenize(text)
192
+ # import pdb; pdb.set_trace()
193
+ phones = []
194
+ tones = []
195
+ words = re.split(r"([,;.\-\?\!\s+])", text)
196
+ for w in words:
197
+ if w.upper() in eng_dict:
198
+ phns, tns = refine_syllables(eng_dict[w.upper()])
199
+ phones += phns
200
+ tones += tns
201
+ else:
202
+ phone_list = list(filter(lambda p: p != " ", _g2p(w)))
203
+ for ph in phone_list:
204
+ if ph in arpa:
205
+ ph, tn = refine_ph(ph)
206
+ phones.append(ph)
207
+ tones.append(tn)
208
+ else:
209
+ phones.append(ph)
210
+ tones.append(0)
211
+ # todo: implement word2ph
212
+ word2ph = [1 for i in phones]
213
+
214
+ phones = [post_replace_ph(i) for i in phones]
215
+ return phones, tones, word2ph
216
+
217
+ def g2p(text, pad_start_end=True, tokenized=None):
218
+ if tokenized is None:
219
+ tokenized = tokenizer.tokenize(text)
220
+ # import pdb; pdb.set_trace()
221
+ phs = []
222
+ ph_groups = []
223
+ for t in tokenized:
224
+ if not t.startswith("#"):
225
+ ph_groups.append([t])
226
+ else:
227
+ ph_groups[-1].append(t.replace("#", ""))
228
+
229
+ phones = []
230
+ tones = []
231
+ word2ph = []
232
+ for group in ph_groups:
233
+ w = "".join(group)
234
+ phone_len = 0
235
+ word_len = len(group)
236
+ if w.upper() in eng_dict:
237
+ phns, tns = refine_syllables(eng_dict[w.upper()])
238
+ phones += phns
239
+ tones += tns
240
+ phone_len += len(phns)
241
+ else:
242
+ phone_list = list(filter(lambda p: p != " ", _g2p(w)))
243
+ for ph in phone_list:
244
+ if ph in arpa:
245
+ ph, tn = refine_ph(ph)
246
+ phones.append(ph)
247
+ tones.append(tn)
248
+ else:
249
+ phones.append(ph)
250
+ tones.append(0)
251
+ phone_len += 1
252
+ aaa = distribute_phone(phone_len, word_len)
253
+ word2ph += aaa
254
+ phones = [post_replace_ph(i) for i in phones]
255
+
256
+ if pad_start_end:
257
+ phones = ["_"] + phones + ["_"]
258
+ tones = [0] + tones + [0]
259
+ word2ph = [1] + word2ph + [1]
260
+ return phones, tones, word2ph
261
+
262
+ def get_bert_feature(text, word2ph, device=None):
263
+ from text import english_bert
264
+
265
+ return english_bert.get_bert_feature(text, word2ph, device=device)
266
+
267
+ if __name__ == "__main__":
268
+ # print(get_dict())
269
+ # print(eng_word_to_phoneme("hello"))
270
+ from text.english_bert import get_bert_feature
271
+ text = "In this paper, we propose 1 DSPGAN, a N-F-T GAN-based universal vocoder."
272
+ text = text_normalize(text)
273
+ phones, tones, word2ph = g2p(text)
274
+ import pdb; pdb.set_trace()
275
+ bert = get_bert_feature(text, word2ph)
276
+
277
+ print(phones, tones, word2ph, bert.shape)
278
+
279
+ # all_phones = set()
280
+ # for k, syllables in eng_dict.items():
281
+ # for group in syllables:
282
+ # for ph in group:
283
+ # all_phones.add(ph)
284
+ # print(all_phones)
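An end-to-end sketch (not part of the commit) of the English front end defined above. It downloads bert-base-uncased on first use and needs g2p_en installed; the import path is assumed from the file layout.

    from melo.text.english import text_normalize, g2p

    text = text_normalize("Dr. Smith arrives at 10:30 am with 3 dogs.")
    phones, tones, word2ph = g2p(text)

    # word2ph distributes the phone count over the BERT tokens (plus the two "_" pads),
    # so the three sequences stay aligned.
    assert sum(word2ph) == len(phones) == len(tones)
    print(phones, tones, word2ph)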
MeloTTS/melo/text/english_bert.py ADDED
@@ -0,0 +1,39 @@
1
+ import torch
2
+ from transformers import AutoTokenizer, AutoModelForMaskedLM
3
+ import sys
4
+
5
+ model_id = 'bert-base-uncased'
6
+ tokenizer = AutoTokenizer.from_pretrained(model_id)
7
+ model = None
8
+
9
+ def get_bert_feature(text, word2ph, device=None):
10
+ global model
11
+ if (
12
+ sys.platform == "darwin"
13
+ and torch.backends.mps.is_available()
14
+ and device == "cpu"
15
+ ):
16
+ device = "mps"
17
+ if not device:
18
+ device = "cuda"
19
+ if model is None:
20
+ model = AutoModelForMaskedLM.from_pretrained(model_id).to(
21
+ device
22
+ )
23
+ with torch.no_grad():
24
+ inputs = tokenizer(text, return_tensors="pt")
25
+ for i in inputs:
26
+ inputs[i] = inputs[i].to(device)
27
+ res = model(**inputs, output_hidden_states=True)
28
+ res = torch.cat(res["hidden_states"][-3:-2], -1)[0].cpu()
29
+
30
+ assert inputs["input_ids"].shape[-1] == len(word2ph)
31
+ word2phone = word2ph
32
+ phone_level_feature = []
33
+ for i in range(len(word2phone)):
34
+ repeat_feature = res[i].repeat(word2phone[i], 1)
35
+ phone_level_feature.append(repeat_feature)
36
+
37
+ phone_level_feature = torch.cat(phone_level_feature, dim=0)
38
+
39
+ return phone_level_feature.T
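For illustration only (not part of the commit): get_bert_feature expects one word2ph entry per BERT token of text, including [CLS] and [SEP], otherwise the assert above fires. A dummy CPU call, with made-up phone counts:

    from melo.text.english_bert import get_bert_feature

    text = "hello world"            # tokenizes to [CLS] hello world [SEP] -> 4 tokens
    word2ph = [1, 2, 3, 1]          # hypothetical phone counts per token
    feats = get_bert_feature(text, word2ph, device="cpu")
    print(feats.shape)              # torch.Size([768, 7]) -- hidden size x sum(word2ph)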
MeloTTS/melo/text/english_utils/__init__.py ADDED
File without changes
MeloTTS/melo/text/english_utils/abbreviations.py ADDED
@@ -0,0 +1,35 @@
1
+ import re
2
+
3
+ # List of (regular expression, replacement) pairs for abbreviations in english:
4
+ abbreviations_en = [
5
+ (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
6
+ for x in [
7
+ ("mrs", "misess"),
8
+ ("mr", "mister"),
9
+ ("dr", "doctor"),
10
+ ("st", "saint"),
11
+ ("co", "company"),
12
+ ("jr", "junior"),
13
+ ("maj", "major"),
14
+ ("gen", "general"),
15
+ ("drs", "doctors"),
16
+ ("rev", "reverend"),
17
+ ("lt", "lieutenant"),
18
+ ("hon", "honorable"),
19
+ ("sgt", "sergeant"),
20
+ ("capt", "captain"),
21
+ ("esq", "esquire"),
22
+ ("ltd", "limited"),
23
+ ("col", "colonel"),
24
+ ("ft", "fort"),
25
+ ]
26
+ ]
27
+
28
+ def expand_abbreviations(text, lang="en"):
29
+ if lang == "en":
30
+ _abbreviations = abbreviations_en
31
+ else:
32
+ raise NotImplementedError()
33
+ for regex, replacement in _abbreviations:
34
+ text = re.sub(regex, replacement, text)
35
+ return text
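The patterns above are case-insensitive and only trigger on the abbreviation followed by a period; for example (sketch, not part of the commit):

    from melo.text.english_utils.abbreviations import expand_abbreviations

    print(expand_abbreviations("Dr. Jones met Mrs. Smith at the co. office."))
    # -> "doctor Jones met misess Smith at the company office."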
MeloTTS/melo/text/english_utils/number_norm.py ADDED
@@ -0,0 +1,97 @@
1
+ """ from https://github.com/keithito/tacotron """
2
+
3
+ import re
4
+ from typing import Dict
5
+
6
+ import inflect
7
+
8
+ _inflect = inflect.engine()
9
+ _comma_number_re = re.compile(r"([0-9][0-9\,]+[0-9])")
10
+ _decimal_number_re = re.compile(r"([0-9]+\.[0-9]+)")
11
+ _currency_re = re.compile(r"(£|\$|¥)([0-9\,\.]*[0-9]+)")
12
+ _ordinal_re = re.compile(r"[0-9]+(st|nd|rd|th)")
13
+ _number_re = re.compile(r"-?[0-9]+")
14
+
15
+
16
+ def _remove_commas(m):
17
+ return m.group(1).replace(",", "")
18
+
19
+
20
+ def _expand_decimal_point(m):
21
+ return m.group(1).replace(".", " point ")
22
+
23
+
24
+ def __expand_currency(value: str, inflection: Dict[float, str]) -> str:
25
+ parts = value.replace(",", "").split(".")
26
+ if len(parts) > 2:
27
+ return f"{value} {inflection[2]}" # Unexpected format
28
+ text = []
29
+ integer = int(parts[0]) if parts[0] else 0
30
+ if integer > 0:
31
+ integer_unit = inflection.get(integer, inflection[2])
32
+ text.append(f"{integer} {integer_unit}")
33
+ fraction = int(parts[1]) if len(parts) > 1 and parts[1] else 0
34
+ if fraction > 0:
35
+ fraction_unit = inflection.get(fraction / 100, inflection[0.02])
36
+ text.append(f"{fraction} {fraction_unit}")
37
+ if len(text) == 0:
38
+ return f"zero {inflection[2]}"
39
+ return " ".join(text)
40
+
41
+
42
+ def _expand_currency(m: "re.Match") -> str:
43
+ currencies = {
44
+ "$": {
45
+ 0.01: "cent",
46
+ 0.02: "cents",
47
+ 1: "dollar",
48
+ 2: "dollars",
49
+ },
50
+ "€": {
51
+ 0.01: "cent",
52
+ 0.02: "cents",
53
+ 1: "euro",
54
+ 2: "euros",
55
+ },
56
+ "£": {
57
+ 0.01: "penny",
58
+ 0.02: "pence",
59
+ 1: "pound sterling",
60
+ 2: "pounds sterling",
61
+ },
62
+ "¥": {
63
+ # TODO rin
64
+ 0.02: "sen",
65
+ 2: "yen",
66
+ },
67
+ }
68
+ unit = m.group(1)
69
+ currency = currencies[unit]
70
+ value = m.group(2)
71
+ return __expand_currency(value, currency)
72
+
73
+
74
+ def _expand_ordinal(m):
75
+ return _inflect.number_to_words(m.group(0))
76
+
77
+
78
+ def _expand_number(m):
79
+ num = int(m.group(0))
80
+ if 1000 < num < 3000:
81
+ if num == 2000:
82
+ return "two thousand"
83
+ if 2000 < num < 2010:
84
+ return "two thousand " + _inflect.number_to_words(num % 100)
85
+ if num % 100 == 0:
86
+ return _inflect.number_to_words(num // 100) + " hundred"
87
+ return _inflect.number_to_words(num, andword="", zero="oh", group=2).replace(", ", " ")
88
+ return _inflect.number_to_words(num, andword="")
89
+
90
+
91
+ def normalize_numbers(text):
92
+ text = re.sub(_comma_number_re, _remove_commas, text)
93
+ text = re.sub(_currency_re, _expand_currency, text)
94
+ text = re.sub(_decimal_number_re, _expand_decimal_point, text)
95
+ text = re.sub(_ordinal_re, _expand_ordinal, text)
96
+ text = re.sub(_number_re, _expand_number, text)
97
+ return text
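A quick sketch (not part of the commit) of what normalize_numbers produces for typical inputs, assuming inflect is installed:

    from melo.text.english_utils.number_norm import normalize_numbers

    print(normalize_numbers("The ticket costs $3.50, not 1,000 dollars."))
    # roughly: "The ticket costs three dollars fifty cents, not one thousand dollars."
    print(normalize_numbers("She finished 2nd in 2019."))
    # roughly: "She finished second in twenty nineteen."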
MeloTTS/melo/text/english_utils/time_norm.py ADDED
@@ -0,0 +1,47 @@
1
+ import re
2
+
3
+ import inflect
4
+
5
+ _inflect = inflect.engine()
6
+
7
+ _time_re = re.compile(
8
+ r"""\b
9
+ ((0?[0-9])|(1[0-1])|(1[2-9])|(2[0-3])) # hours
10
+ :
11
+ ([0-5][0-9]) # minutes
12
+ \s*(a\.m\.|am|pm|p\.m\.|a\.m|p\.m)? # am/pm
13
+ \b""",
14
+ re.IGNORECASE | re.X,
15
+ )
16
+
17
+
18
+ def _expand_num(n: int) -> str:
19
+ return _inflect.number_to_words(n)
20
+
21
+
22
+ def _expand_time_english(match: "re.Match") -> str:
23
+ hour = int(match.group(1))
24
+ past_noon = hour >= 12
25
+ time = []
26
+ if hour > 12:
27
+ hour -= 12
28
+ elif hour == 0:
29
+ hour = 12
30
+ past_noon = True
31
+ time.append(_expand_num(hour))
32
+
33
+ minute = int(match.group(6))
34
+ if minute > 0:
35
+ if minute < 10:
36
+ time.append("oh")
37
+ time.append(_expand_num(minute))
38
+ am_pm = match.group(7)
39
+ if am_pm is None:
40
+ time.append("p m" if past_noon else "a m")
41
+ else:
42
+ time.extend(list(am_pm.replace(".", "")))
43
+ return " ".join(time)
44
+
45
+
46
+ def expand_time_english(text: str) -> str:
47
+ return re.sub(_time_re, _expand_time_english, text)
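expand_time_english spells out clock times before phonemization; when no am/pm marker is present it infers one from the hour, as this sketch (not part of the commit) shows:

    from melo.text.english_utils.time_norm import expand_time_english

    print(expand_time_english("The call is at 9:05 pm, not 10:30."))
    # -> "The call is at nine oh five p m, not ten thirty a m."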
MeloTTS/melo/text/es_phonemizer/__init__.py ADDED
File without changes
MeloTTS/melo/text/es_phonemizer/base.py ADDED
@@ -0,0 +1,140 @@
1
+ import abc
2
+ from typing import List, Tuple
3
+
4
+ from .punctuation import Punctuation
5
+
6
+
7
+ class BasePhonemizer(abc.ABC):
8
+ """Base phonemizer class
9
+
10
+ Phonemization follows the following steps:
11
+ 1. Preprocessing:
12
+ - remove empty lines
13
+ - remove punctuation
14
+ - keep track of punctuation marks
15
+
16
+ 2. Phonemization:
17
+ - convert text to phonemes
18
+
19
+ 3. Postprocessing:
20
+ - join phonemes
21
+ - restore punctuation marks
22
+
23
+ Args:
24
+ language (str):
25
+ Language used by the phonemizer.
26
+
27
+ punctuations (List[str]):
28
+ List of punctuation marks to be preserved.
29
+
30
+ keep_puncs (bool):
31
+ Whether to preserve punctuation marks or not.
32
+ """
33
+
34
+ def __init__(self, language, punctuations=Punctuation.default_puncs(), keep_puncs=False):
35
+ # ensure the backend is installed on the system
36
+ if not self.is_available():
37
+ raise RuntimeError("{} not installed on your system".format(self.name())) # pragma: nocover
38
+
39
+ # ensure the backend support the requested language
40
+ self._language = self._init_language(language)
41
+
42
+ # setup punctuation processing
43
+ self._keep_puncs = keep_puncs
44
+ self._punctuator = Punctuation(punctuations)
45
+
46
+ def _init_language(self, language):
47
+ """Language initialization
48
+
49
+ This method may be overloaded in child classes (see Segments backend)
50
+
51
+ """
52
+ if not self.is_supported_language(language):
53
+ raise RuntimeError(f'language "{language}" is not supported by the ' f"{self.name()} backend")
54
+ return language
55
+
56
+ @property
57
+ def language(self):
58
+ """The language code configured to be used for phonemization"""
59
+ return self._language
60
+
61
+ @staticmethod
62
+ @abc.abstractmethod
63
+ def name():
64
+ """The name of the backend"""
65
+ ...
66
+
67
+ @classmethod
68
+ @abc.abstractmethod
69
+ def is_available(cls):
70
+ """Returns True if the backend is installed, False otherwise"""
71
+ ...
72
+
73
+ @classmethod
74
+ @abc.abstractmethod
75
+ def version(cls):
76
+ """Return the backend version as a tuple (major, minor, patch)"""
77
+ ...
78
+
79
+ @staticmethod
80
+ @abc.abstractmethod
81
+ def supported_languages():
82
+ """Return a dict of language codes -> name supported by the backend"""
83
+ ...
84
+
85
+ def is_supported_language(self, language):
86
+ """Returns True if `language` is supported by the backend"""
87
+ return language in self.supported_languages()
88
+
89
+ @abc.abstractmethod
90
+ def _phonemize(self, text, separator):
91
+ """The main phonemization method"""
92
+
93
+ def _phonemize_preprocess(self, text) -> Tuple[List[str], List]:
94
+ """Preprocess the text before phonemization
95
+
96
+ 1. remove spaces
97
+ 2. remove punctuation
98
+
99
+ Override this if you need a different behaviour
100
+ """
101
+ text = text.strip()
102
+ if self._keep_puncs:
103
+ # a tuple (text, punctuation marks)
104
+ return self._punctuator.strip_to_restore(text)
105
+ return [self._punctuator.strip(text)], []
106
+
107
+ def _phonemize_postprocess(self, phonemized, punctuations) -> str:
108
+ """Postprocess the raw phonemized output
109
+
110
+ Override this if you need a different behaviour
111
+ """
112
+ if self._keep_puncs:
113
+ return self._punctuator.restore(phonemized, punctuations)[0]
114
+ return phonemized[0]
115
+
116
+ def phonemize(self, text: str, separator="|", language: str = None) -> str: # pylint: disable=unused-argument
117
+ """Returns the `text` phonemized for the given language
118
+
119
+ Args:
120
+ text (str):
121
+ Text to be phonemized.
122
+
123
+ separator (str):
124
+ string separator used between phonemes. Defaults to "|".
125
+
126
+ Returns:
127
+ (str): Phonemized text
128
+ """
129
+ text, punctuations = self._phonemize_preprocess(text)
130
+ phonemized = []
131
+ for t in text:
132
+ p = self._phonemize(t, separator)
133
+ phonemized.append(p)
134
+ phonemized = self._phonemize_postprocess(phonemized, punctuations)
135
+ return phonemized
136
+
137
+ def print_logs(self, level: int = 0):
138
+ indent = "\t" * level
139
+ print(f"{indent}| > phoneme language: {self.language}")
140
+ print(f"{indent}| > phoneme backend: {self.name()}")
MeloTTS/melo/text/es_phonemizer/cleaner.py ADDED
@@ -0,0 +1,109 @@
1
+ """Set of default text cleaners"""
2
+ # TODO: pick the cleaner for languages dynamically
3
+
4
+ import re
5
+
6
+ # Regular expression matching whitespace:
7
+ _whitespace_re = re.compile(r"\s+")
8
+
9
+ rep_map = {
10
+ ":": ",",
11
+ ";": ",",
12
+ ",": ",",
13
+ "。": ".",
14
+ "!": "!",
15
+ "?": "?",
16
+ "\n": ".",
17
+ "·": ",",
18
+ "、": ",",
19
+ "...": ".",
20
+ "…": ".",
21
+ "$": ".",
22
+ "“": "'",
23
+ "”": "'",
24
+ "‘": "'",
25
+ "’": "'",
26
+ "(": "'",
27
+ ")": "'",
28
+ "(": "'",
29
+ ")": "'",
30
+ "《": "'",
31
+ "》": "'",
32
+ "【": "'",
33
+ "】": "'",
34
+ "[": "'",
35
+ "]": "'",
36
+ "—": "",
37
+ "~": "-",
38
+ "~": "-",
39
+ "「": "'",
40
+ "」": "'",
41
+ }
42
+
43
+ def replace_punctuation(text):
44
+ pattern = re.compile("|".join(re.escape(p) for p in rep_map.keys()))
45
+ replaced_text = pattern.sub(lambda x: rep_map[x.group()], text)
46
+ return replaced_text
47
+
48
+ def lowercase(text):
49
+ return text.lower()
50
+
51
+
52
+ def collapse_whitespace(text):
53
+ return re.sub(_whitespace_re, " ", text).strip()
54
+
55
+ def remove_punctuation_at_begin(text):
56
+ return re.sub(r'^[,.!?]+', '', text)
57
+
58
+ def remove_aux_symbols(text):
59
+ text = re.sub(r"[\<\>\(\)\[\]\"\«\»\']+", "", text)
60
+ return text
61
+
62
+
63
+ def replace_symbols(text, lang="en"):
64
+ """Replace symbols based on the lenguage tag.
65
+
66
+ Args:
67
+ text:
68
+ Input text.
69
+ lang:
70
+ Language identifier. ex: "en", "fr", "pt", "ca".
71
+
72
+ Returns:
73
+ The modified text
74
+ example:
75
+ input args:
76
+ text: "si l'avi cau, diguem-ho"
77
+ lang: "ca"
78
+ Output:
79
+ text: "si lavi cau, diguemho"
80
+ """
81
+ text = text.replace(";", ",")
82
+ text = text.replace("-", " ") if lang != "ca" else text.replace("-", "")
83
+ text = text.replace(":", ",")
84
+ if lang == "en":
85
+ text = text.replace("&", " and ")
86
+ elif lang == "fr":
87
+ text = text.replace("&", " et ")
88
+ elif lang == "pt":
89
+ text = text.replace("&", " e ")
90
+ elif lang == "ca":
91
+ text = text.replace("&", " i ")
92
+ text = text.replace("'", "")
93
+ elif lang == "es":
94
+ text = text.replace("&", "y")
95
+ text = text.replace("'", "")
96
+ return text
97
+
98
+ def spanish_cleaners(text):
99
+ """Basic pipeline for Portuguese text. There is no need to expand abbreviation and
100
+ numbers; the phonemizer already does that."""
101
+ text = lowercase(text)
102
+ text = replace_symbols(text, lang="es")
103
+ text = replace_punctuation(text)
104
+ text = remove_aux_symbols(text)
105
+ text = remove_punctuation_at_begin(text)
106
+ text = collapse_whitespace(text)
107
+ text = re.sub(r'([^\.,!\?\-…])$', r'\1.', text)
108
+ return text
109
+
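spanish_cleaners chains the helpers above; a short sketch of its effect (not part of the commit):

    from melo.text.es_phonemizer.cleaner import spanish_cleaners

    print(spanish_cleaners("¿Cómo estás?  Muy bien; gracias"))
    # lower-cased, ';' -> ',', whitespace collapsed, and a final '.' appended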
MeloTTS/melo/text/es_phonemizer/es_symbols.json ADDED
@@ -0,0 +1,79 @@
1
+ {
2
+ "symbols": [
3
+ "_",
4
+ ",",
5
+ ".",
6
+ "!",
7
+ "?",
8
+ "-",
9
+ "~",
10
+ "\u2026",
11
+ "N",
12
+ "Q",
13
+ "a",
14
+ "b",
15
+ "d",
16
+ "e",
17
+ "f",
18
+ "g",
19
+ "h",
20
+ "i",
21
+ "j",
22
+ "k",
23
+ "l",
24
+ "m",
25
+ "n",
26
+ "o",
27
+ "p",
28
+ "s",
29
+ "t",
30
+ "u",
31
+ "v",
32
+ "w",
33
+ "x",
34
+ "y",
35
+ "z",
36
+ "\u0251",
37
+ "\u00e6",
38
+ "\u0283",
39
+ "\u0291",
40
+ "\u00e7",
41
+ "\u026f",
42
+ "\u026a",
43
+ "\u0254",
44
+ "\u025b",
45
+ "\u0279",
46
+ "\u00f0",
47
+ "\u0259",
48
+ "\u026b",
49
+ "\u0265",
50
+ "\u0278",
51
+ "\u028a",
52
+ "\u027e",
53
+ "\u0292",
54
+ "\u03b8",
55
+ "\u03b2",
56
+ "\u014b",
57
+ "\u0266",
58
+ "\u207c",
59
+ "\u02b0",
60
+ "`",
61
+ "^",
62
+ "#",
63
+ "*",
64
+ "=",
65
+ "\u02c8",
66
+ "\u02cc",
67
+ "\u2192",
68
+ "\u2193",
69
+ "\u2191",
70
+ " ",
71
+ "\u0263",
72
+ "\u0261",
73
+ "r",
74
+ "\u0272",
75
+ "\u029d",
76
+ "\u028e",
77
+ "\u02d0"
78
+ ]
79
+ }
MeloTTS/melo/text/es_phonemizer/es_symbols.txt ADDED
@@ -0,0 +1 @@
1
+ _,.!?-~…NQabdefghijklmnopstuvwxyzɑæʃʑçɯɪɔɛɹðəɫɥɸʊɾʒθβŋɦ⁼ʰ`^#*=ˈˌ→↓↑ ɡrɲʝɣʎː—¿¡
MeloTTS/melo/text/es_phonemizer/es_symbols_v2.json ADDED
@@ -0,0 +1,83 @@
1
+ {
2
+ "symbols": [
3
+ "_",
4
+ ",",
5
+ ".",
6
+ "!",
7
+ "?",
8
+ "-",
9
+ "~",
10
+ "\u2026",
11
+ "N",
12
+ "Q",
13
+ "a",
14
+ "b",
15
+ "d",
16
+ "e",
17
+ "f",
18
+ "g",
19
+ "h",
20
+ "i",
21
+ "j",
22
+ "k",
23
+ "l",
24
+ "m",
25
+ "n",
26
+ "o",
27
+ "p",
28
+ "s",
29
+ "t",
30
+ "u",
31
+ "v",
32
+ "w",
33
+ "x",
34
+ "y",
35
+ "z",
36
+ "\u0251",
37
+ "\u00e6",
38
+ "\u0283",
39
+ "\u0291",
40
+ "\u00e7",
41
+ "\u026f",
42
+ "\u026a",
43
+ "\u0254",
44
+ "\u025b",
45
+ "\u0279",
46
+ "\u00f0",
47
+ "\u0259",
48
+ "\u026b",
49
+ "\u0265",
50
+ "\u0278",
51
+ "\u028a",
52
+ "\u027e",
53
+ "\u0292",
54
+ "\u03b8",
55
+ "\u03b2",
56
+ "\u014b",
57
+ "\u0266",
58
+ "\u207c",
59
+ "\u02b0",
60
+ "`",
61
+ "^",
62
+ "#",
63
+ "*",
64
+ "=",
65
+ "\u02c8",
66
+ "\u02cc",
67
+ "\u2192",
68
+ "\u2193",
69
+ "\u2191",
70
+ " ",
71
+ "\u0261",
72
+ "r",
73
+ "\u0272",
74
+ "\u029d",
75
+ "\u0263",
76
+ "\u028e",
77
+ "\u02d0",
78
+
79
+ "\u2014",
80
+ "\u00bf",
81
+ "\u00a1"
82
+ ]
83
+ }