Update README.md
Browse files
README.md
CHANGED
@@ -46,11 +46,50 @@ The following hyperparameters were used during training:
|
|
46 |
|
47 |
### Training results
|
48 |
|
49 |
-
|
50 |
-
|
51 |
### Framework versions
|
52 |
|
53 |
- PEFT 0.14.0
|
54 |
- Transformers 4.47.0
|
55 |
- Pytorch 2.2.1+cu121
|
56 |
-
- Tokenizers 0.21.0
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
46 |
|
47 |
### Training results
|
48 |
|
|
|
|
|
49 |
### Framework versions
|
50 |
|
51 |
- PEFT 0.14.0
|
52 |
- Transformers 4.47.0
|
53 |
- Pytorch 2.2.1+cu121
|
54 |
+
- Tokenizers 0.21.0
|
55 |
+
|
56 |
+
## Inference
|
57 |
+
|
58 |
+
```python
|
59 |
+
from transformers import AutoProcessor, PaliGemmaForConditionalGeneration
|
60 |
+
from PIL import Image
|
61 |
+
import torch
|
62 |
+
import json
|
63 |
+
|
64 |
+
# Load model and processor
|
65 |
+
model_id = "google/paligemma-3b-pt-448"
|
66 |
+
peft_adapter_id = "riphunter7001x/PaliGemma3_FT_OCR"
|
67 |
+
|
68 |
+
model = PaliGemmaForConditionalGeneration.from_pretrained(model_id, device_map="auto")
|
69 |
+
processor = AutoProcessor.from_pretrained(model_id)
|
70 |
+
# load_adapter() attaches the PEFT adapter to the model in place and returns
# None, so eval() has to be called on the model itself rather than chained.
model.load_adapter(peft_adapter_id)
model.eval()
|
71 |
+
|
72 |
+
TORCH_DTYPE = model.dtype
|
73 |
+
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
74 |
+
|
75 |
+
# Load and process image
|
76 |
+
image = Image.open("image.jpg")
|
77 |
+
|
78 |
+
prefix = "<image>extract Document data in JSON format"
|
79 |
+
|
80 |
+
inputs = processor(
|
81 |
+
text=prefix,
|
82 |
+
images=image,
|
83 |
+
return_tensors="pt"
|
84 |
+
).to(TORCH_DTYPE).to(DEVICE)
|
85 |
+
|
86 |
+
prefix_length = inputs["input_ids"].shape[-1]
|
87 |
+
|
88 |
+
with torch.inference_mode():
|
89 |
+
generation = model.generate(**inputs, max_new_tokens=512, do_sample=False)
|
90 |
+
generation = generation[0][prefix_length:]
|
91 |
+
decoded = processor.decode(generation, skip_special_tokens=True)
|
92 |
+
print(json.dumps(json.loads(decoded), indent=4))
|
93 |
+
```
|
94 |
+
|
95 |
+
This code loads the base PaliGemma model, attaches the fine-tuned PEFT adapter, processes an input image, and prints the extracted document data as formatted JSON.
|