#!/usr/bin/env python
# encoding: utf-8
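"""
LucaOne alphabet_atom.py: atom-level alphabet/tokenizer that converts SMILES
strings into atom-symbol token sequences and vocabulary indices via RDKit.
"""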
from rdkit import Chem
from rdkit.Chem import AllChem
from typing import Sequence, List
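# Standard vocabulary: the 26 element symbols recognized by the atom-level tokenizer.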
atom_standard_toks = ['C', 'N', 'O', 'S', 'H', 'Cl', 'F', 'Br', 'I',
                      'Si', 'P', 'B', 'Na', 'K', 'Al', 'Ca', 'Sn', 'As',
                      'Hg', 'Fe', 'Zn', 'Cr', 'Se', 'Gd', 'Au', 'Li'
                      ]
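# Special tokens: [PAD]/[UNK]/[CLS] are placed at the front of the vocabulary,
# followed by [SEP]/[MASK]; the standard atom symbols come after them.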
atom_prepend_toks = ['[PAD]', '[UNK]', '[CLS]']
atom_append_toks = ['[SEP]', '[MASK]']
class AlphabetAtom(object):
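    """
    Atom-level alphabet for LucaOne: stores the token vocabulary, maps tokens to
    indices and back, and tokenizes/encodes SMILES strings via RDKit.
    """
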
    def __init__(
            self,
            standard_toks: Sequence[str] = atom_standard_toks,
            prepend_toks: Sequence[str] = atom_prepend_toks,
            append_toks: Sequence[str] = atom_append_toks,
            prepend_bos: bool = True,
            append_eos: bool = True
    ):
        self.standard_toks = list(standard_toks)
        self.prepend_toks = list(prepend_toks)
        self.append_toks = list(append_toks)
        self.prepend_bos = prepend_bos
        self.append_eos = append_eos
        self.all_toks = list(self.prepend_toks)
        self.all_toks.extend(self.append_toks)
        self.all_toks.extend(self.standard_toks)
        self.tok_to_idx = {tok: i for i, tok in enumerate(self.all_toks)}
        self.unk_idx = self.tok_to_idx["[UNK]"]
        self.padding_idx = self.get_idx("[PAD]")
        self.pad_idx = self.get_idx("[PAD]")
        self.pad_token_id = self.padding_idx
        self.cls_idx = self.get_idx("[CLS]")
        self.mask_idx = self.get_idx("[MASK]")
        # [CLS] serves as BOS and [SEP] serves as EOS.
        self.eos_idx = self.get_idx("[SEP]")
        self.all_special_tokens = self.prepend_toks + self.append_toks
        self.all_special_token_idx_list = [self.tok_to_idx[v] for v in self.all_special_tokens]
        self.unique_no_split_tokens = self.all_toks
        self.vocab_size = self.__len__()

    def __len__(self):
        return len(self.all_toks)

    def get_idx(self, tok):
        return self.tok_to_idx.get(tok, self.unk_idx)

    def get_tok(self, ind):
        return self.all_toks[ind]

    def to_dict(self):
        return self.tok_to_idx.copy()

    def get_batch_converter(self, task_level_type, label_size, output_mode,
                            no_position_embeddings, no_token_type_embeddings,
                            truncation_seq_length: int = None, ignore_index: int = -100,
                            mlm_probability=0.15):
        # The BatchConverter referenced below is not defined or imported in this
        # file, so the original call is kept only as a commented-out reference.
        '''
        return BatchConverter(
            task_level_type,
            label_size,
            output_mode,
            seq_subword=False,
            seq_tokenizer=self,
            no_position_embeddings=no_position_embeddings,
            no_token_type_embeddings=no_token_type_embeddings,
            truncation_seq_length=truncation_seq_length,
            truncation_matrix_length=truncation_seq_length,
            ignore_index=ignore_index,
            mlm_probability=mlm_probability,
            prepend_bos=self.prepend_bos,
            append_eos=self.append_eos)
        '''
        pass

    @classmethod
    def smiles_2_atom_seq(cls, smi):
        mol = Chem.MolFromSmiles(smi)
        if mol is None:
            raise ValueError("Invalid SMILES: %s" % smi)
        mol = AllChem.AddHs(mol)
        # Atom symbols after explicit hydrogens have been added.
        atoms = [atom.GetSymbol() for atom in mol.GetAtoms()]
        return atoms

    @classmethod
    def from_predefined(cls, name: str = "atom_v1"):
        if name.lower() == "atom_v1":
            standard_toks = atom_standard_toks
        else:
            raise ValueError("Unsupported tokenizer name: %s" % name)
        prepend_toks = atom_prepend_toks
        append_toks = atom_append_toks
        prepend_bos = True
        append_eos = True
        return cls(standard_toks, prepend_toks, append_toks, prepend_bos, append_eos)
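
    # Persistence: the whole alphabet object is pickled to / loaded from
    # "alphabet_atom.pkl" inside the given directory.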
    @classmethod
    def from_pretrained(cls, dir_path):
        import os
        import pickle
        with open(os.path.join(dir_path, "alphabet_atom.pkl"), "rb") as inp:
            return pickle.load(inp)

    def save_pretrained(self, save_dir):
        import os
        import pickle
        with open(os.path.join(save_dir, "alphabet_atom.pkl"), 'wb') as outp:
            pickle.dump(self, outp, pickle.HIGHEST_PROTOCOL)

    def tokenize(self, smi, prepend_bos, append_eos) -> List[str]:
        # SMILES -> list of atom-symbol tokens, optionally wrapped with [CLS]/[SEP].
        seq = AlphabetAtom.smiles_2_atom_seq(smi)
        if prepend_bos:
            seq = [self.get_tok(self.cls_idx)] + seq
        if append_eos:
            seq = seq + [self.get_tok(self.eos_idx)]
        return seq

    def encode(self, atom_list, prepend_bos, append_eos):
        # Atom tokens -> vocabulary indices, optionally adding the BOS/EOS indices.
        idx_list = [self.get_idx(tok) for tok in atom_list]
        if prepend_bos:
            idx_list = [self.cls_idx] + idx_list
        if append_eos:
            idx_list = idx_list + [self.eos_idx]
        return idx_list

    def encode_smi(self, smi, prepend_bos, append_eos):
        # SMILES -> atom tokens -> vocabulary indices.
        atom_list = self.smiles_2_atom_seq(smi)
        return self.encode(atom_list, prepend_bos, append_eos)
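

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module): build the predefined
# alphabet, then tokenize and encode an illustrative SMILES string ("CCO",
# ethanol). Assumes RDKit is installed; the printed indices depend on the
# vocabulary order defined above.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    alphabet = AlphabetAtom.from_predefined("atom_v1")
    smi = "CCO"  # ethanol, chosen only for illustration
    print("vocab size:", alphabet.vocab_size)
    print("tokens:", alphabet.tokenize(smi, prepend_bos=True, append_eos=True))
    print("ids:", alphabet.encode_smi(smi, prepend_bos=True, append_eos=True))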