---
license: mit
language:
- en
library_name: transformers
tags:
- esm-2
- protein
- token classification
- biology
- esm
- rna
- binding site
---

# ESM-2 for RNA Binding Site Prediction

A small RNA binding site predictor fine-tuned from [facebook/esm2_t6_8M_UR50D](https://huggingface.co/facebook/esm2_t6_8M_UR50D) on the "S1" dataset from [Data of protein-RNA binding sites](https://www.sciencedirect.com/science/article/pii/S2352340916308022#s0035).
The dataset is also available on Hugging Face [here](https://huggingface.co/datasets/AmelieSchreiber/data_of_protein-rna_binding_sites).
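
If you want to look at the data directly, here is a minimal sketch using the `datasets` library (the split layout is an assumption; check the dataset card for the actual structure):

```python
from datasets import load_dataset

# Load the binding-site dataset from the Hugging Face Hub
# (split names and column layout are assumptions; inspect the printed summary)
dataset = load_dataset("AmelieSchreiber/data_of_protein-rna_binding_sites")
print(dataset)
```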

The model achieves a validation loss of `0.12738210861297214`.
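
For context, that number is the mean token-classification cross-entropy. Here is a minimal sketch of how such a loss can be computed for a single labeled sequence (the labels below are made up for illustration; special tokens are masked with `-100` so the loss ignores them):

```python
import torch
from transformers import AutoTokenizer, EsmForTokenClassification

tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t6_8M_UR50D")
model = EsmForTokenClassification.from_pretrained(
    "AmelieSchreiber/esm2_t6_8M_UR50D_rna_binding_site_predictor"
)
model.eval()

seq = "VLSPADKTNVKAAWGKVGAHAGEYGAEALERMFLSFPTTK"
enc = tokenizer(seq, return_tensors="pt")

# Hypothetical per-residue labels (0 = not binding site, 1 = binding site);
# ESM adds CLS/EOS tokens, which are masked with -100 so they don't count
labels = torch.tensor([[-100] + [0] * len(seq) + [-100]])

with torch.no_grad():
    loss = model(**enc, labels=labels).loss
print(f"Cross-entropy loss: {loss.item():.4f}")
```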

To use the model, try running:

```python
import torch
from transformers import AutoTokenizer, EsmForTokenClassification

# Define the class mapping
class_mapping = {
    0: 'Not Binding Site',
    1: 'Binding Site',
}

# Load the trained model and tokenizer
model = EsmForTokenClassification.from_pretrained("AmelieSchreiber/esm2_t6_8M_UR50D_rna_binding_site_predictor")
tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t6_8M_UR50D")
model.eval()  # inference mode (disables dropout)

# Define the new sequences
new_sequences = [
    'VLSPADKTNVKAAWGKVGAHAGEYGAEALERMFLSFPTTK',
    'SQETFSDLWKLLPENNVLSPLPSQAMDDLMLSPDDIEQWF',
    # ... add more sequences here ...
]

# Iterate over the new sequences
for seq in new_sequences:
    # Tokenize the sequence (ESM adds a CLS token at the start and an EOS token at the end)
    inputs = tokenizer(seq, truncation=True, max_length=1290, return_tensors="pt")

    # Apply the model to get the logits
    with torch.no_grad():
        outputs = model(**inputs)

    # Get the predictions by picking the label (class) with the highest logit
    predictions = torch.argmax(outputs.logits, dim=-1)

    # Drop the CLS and EOS predictions so the rest align with the residues
    prediction_list = predictions[0, 1:len(seq) + 1].tolist()

    # Convert the predicted class indices to class names
    predicted_labels = [class_mapping[pred] for pred in prediction_list]

    # Match each amino acid in the sequence to its predicted class label
    residue_to_label = list(zip(seq, predicted_labels))

    # Print out the list
    for i, (residue, predicted_label) in enumerate(residue_to_label):
        print(f"Position {i+1} - {residue}: {predicted_label}")
```
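
If you would rather have a compact summary than a per-residue printout, one option is to collect just the positions predicted as binding sites. A small, hypothetical helper:

```python
def binding_site_positions(residue_to_label):
    """Return the 1-based positions of residues predicted as binding sites."""
    return [
        i + 1
        for i, (_, label) in enumerate(residue_to_label)
        if label == 'Binding Site'
    ]

# For example, inside the loop above:
#     print(f"Predicted binding sites: {binding_site_positions(residue_to_label)}")
```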