from typing import List, Optional, Sequence

from rdkit import Chem
from rdkit.Chem import AllChem

# Atom-symbol vocabulary: the element symbols expected to occur in the corpus.
atom_standard_toks = ['C', 'N', 'O', 'S', 'H', 'Cl', 'F', 'Br', 'I',
                      'Si', 'P', 'B', 'Na', 'K', 'Al', 'Ca', 'Sn', 'As',
                      'Hg', 'Fe', 'Zn', 'Cr', 'Se', 'Gd', 'Au', 'Li']

# Special tokens placed before and after the standard vocabulary.
atom_prepend_toks = ['[PAD]', '[UNK]', '[CLS]']
atom_append_toks = ['[SEP]', '[MASK]']


class AlphabetAtom(object):
    """Atom-level alphabet that maps SMILES strings to atom-token sequences."""

    def __init__(
            self,
            standard_toks: Sequence[str] = atom_standard_toks,
            prepend_toks: Sequence[str] = atom_prepend_toks,
            append_toks: Sequence[str] = atom_append_toks,
            prepend_bos: bool = True,
            append_eos: bool = True
    ):
        self.standard_toks = list(standard_toks)
        self.prepend_toks = list(prepend_toks)
        self.append_toks = list(append_toks)
        self.prepend_bos = prepend_bos
        self.append_eos = append_eos

        # Index layout: prepend specials first, then append specials, then the
        # standard tokens. With the defaults above this gives
        # 0=[PAD], 1=[UNK], 2=[CLS], 3=[SEP], 4=[MASK], 5='C', 6='N', ...
        self.all_toks = list(self.prepend_toks)
        self.all_toks.extend(self.append_toks)
        self.all_toks.extend(self.standard_toks)

        self.tok_to_idx = {tok: i for i, tok in enumerate(self.all_toks)}

        self.unk_idx = self.tok_to_idx["[UNK]"]
        self.padding_idx = self.get_idx("[PAD]")
        # Aliases of padding_idx kept so callers can use either naming convention.
        self.pad_idx = self.padding_idx
        self.pad_token_id = self.padding_idx
        self.cls_idx = self.get_idx("[CLS]")
        self.mask_idx = self.get_idx("[MASK]")
        self.eos_idx = self.get_idx("[SEP]")
        self.all_special_tokens = self.prepend_toks + self.append_toks
        self.all_special_token_idx_list = [self.tok_to_idx[v] for v in self.all_special_tokens]
        # Copy, so mutating unique_no_split_tokens cannot corrupt the vocabulary.
        self.unique_no_split_tokens = list(self.all_toks)
        self.vocab_size = len(self)

    def __len__(self):
        return len(self.all_toks)

    def get_idx(self, tok):
        # Unknown tokens fall back to the [UNK] index.
        return self.tok_to_idx.get(tok, self.unk_idx)

    def get_tok(self, ind):
        return self.all_toks[ind]

    def to_dict(self):
        return self.tok_to_idx.copy()

    def get_batch_converter(self, task_level_type, label_size, output_mode,
                            no_position_embeddings, no_token_type_embeddings,
                            truncation_seq_length: Optional[int] = None,
                            ignore_index: int = -100, mlm_probability=0.15):
        # Stub: the BatchConverter wiring below is kept (commented out) for
        # reference, but is not implemented in this module.
        '''
        return BatchConverter(
            task_level_type,
            label_size,
            output_mode,
            seq_subword=False,
            seq_tokenizer=self,
            no_position_embeddings=no_position_embeddings,
            no_token_type_embeddings=no_token_type_embeddings,
            truncation_seq_length=truncation_seq_length,
            truncation_matrix_length=truncation_seq_length,
            ignore_index=ignore_index,
            mlm_probability=mlm_probability,
            prepend_bos=self.prepend_bos,
            append_eos=self.append_eos)
        '''
        pass

    @classmethod
    def smiles_2_atom_seq(cls, smi):
        mol = Chem.MolFromSmiles(smi)
        if mol is None:
            # MolFromSmiles returns None for unparsable SMILES; fail loudly
            # instead of letting AddHs raise a cryptic error.
            raise ValueError("Invalid SMILES: %s" % smi)
        mol = AllChem.AddHs(mol)  # make hydrogens explicit so 'H' tokens appear
        atoms = [atom.GetSymbol() for atom in mol.GetAtoms()]
        return atoms

    @classmethod
    def from_predefined(cls, name: str = "atom_v1"):
        if name.lower() == "atom_v1":
            standard_toks = atom_standard_toks
        else:
            raise ValueError("Unsupported tokenizer name: %s" % name)

        prepend_toks = atom_prepend_toks
        append_toks = atom_append_toks
        prepend_bos = True
        append_eos = True

        return cls(standard_toks, prepend_toks, append_toks, prepend_bos, append_eos)

    @classmethod
    def from_pretrained(cls, dir_path):
        import os
        import pickle
        # Use a context manager so the file handle is closed promptly.
        with open(os.path.join(dir_path, "alphabet_atom.pkl"), "rb") as inp:
            return pickle.load(inp)

    def save_pretrained(self, save_dir):
        import os
        import pickle
        with open(os.path.join(save_dir, "alphabet_atom.pkl"), 'wb') as outp:
            pickle.dump(self, outp, pickle.HIGHEST_PROTOCOL)

    def tokenize(self, smi, prepend_bos, append_eos) -> List[str]:
        seq = self.smiles_2_atom_seq(smi)
        if prepend_bos:
            seq = [self.get_tok(self.cls_idx)] + seq
        if append_eos:
            seq = seq + [self.get_tok(self.eos_idx)]
        return seq

    def encode(self, atom_list, prepend_bos, append_eos):
        idx_list = [self.get_idx(tok) for tok in atom_list]
        if prepend_bos:
            idx_list = [self.cls_idx] + idx_list
        if append_eos:
            idx_list = idx_list + [self.eos_idx]
        return idx_list

    def encode_smi(self, smi, prepend_bos, append_eos):
        # Convenience wrapper: SMILES -> atom symbols -> token ids.
        atom_list = self.smiles_2_atom_seq(smi)
        return self.encode(atom_list, prepend_bos, append_eos)
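

# ---------------------------------------------------------------------------
# Minimal usage sketch, not part of the library API: builds the predefined
# alphabet, tokenizes/encodes an example SMILES, and round-trips the alphabet
# through save_pretrained/from_pretrained. Requires RDKit; the aspirin SMILES
# and the temporary directory below are illustrative choices, not fixtures
# from the original code.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import tempfile

    alphabet = AlphabetAtom.from_predefined("atom_v1")
    smi = "CC(=O)Oc1ccccc1C(=O)O"  # aspirin, an arbitrary test molecule

    toks = alphabet.tokenize(smi, prepend_bos=True, append_eos=True)
    ids = alphabet.encode_smi(smi, prepend_bos=True, append_eos=True)
    print("tokens:", toks[:8], "...")
    print("ids:   ", ids[:8], "...")
    # tokenize and encode_smi add the same [CLS]/[SEP] wrappers, so the
    # symbol sequence and the id sequence must line up one-to-one.
    assert len(toks) == len(ids)

    # Save/load round trip through the pickle file used by *_pretrained.
    with tempfile.TemporaryDirectory() as tmp_dir:
        alphabet.save_pretrained(tmp_dir)
        reloaded = AlphabetAtom.from_pretrained(tmp_dir)
        assert reloaded.to_dict() == alphabet.to_dict()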