riphunter7001x committed on
Commit 9612b7d · verified · 1 Parent(s): 4f73a4f

Update README.md

Files changed (1)
  1. README.md +42 -3
README.md CHANGED
@@ -46,11 +46,50 @@ The following hyperparameters were used during training:
 
  ### Training results
 
-
-
  ### Framework versions
 
  - PEFT 0.14.0
  - Transformers 4.47.0
  - Pytorch 2.2.1+cu121
- - Tokenizers 0.21.0
+ - Tokenizers 0.21.0
+
+ ## Inference
+
+ ```python
+ from transformers import AutoProcessor, PaliGemmaForConditionalGeneration
+ from PIL import Image
+ import torch
+ import json
+
+ # Load the base model and processor, then attach the fine-tuned PEFT adapter
+ model_id = "google/paligemma-3b-pt-448"
+ peft_adapter_id = "riphunter7001x/PaliGemma3_FT_OCR"
+
+ model = PaliGemmaForConditionalGeneration.from_pretrained(model_id, device_map="auto")
+ processor = AutoProcessor.from_pretrained(model_id)
+ model.load_adapter(peft_adapter_id)  # returns None, so eval() is called separately
+ model.eval()
+
+ TORCH_DTYPE = model.dtype
+ DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+ # Load the image and build the prompt
+ image = Image.open("image.jpg")
+ prefix = "<image>extract Document data in JSON format"
+
+ inputs = processor(
+     text=prefix,
+     images=image,
+     return_tensors="pt"
+ ).to(TORCH_DTYPE).to(DEVICE)
+
+ prefix_length = inputs["input_ids"].shape[-1]
+
+ with torch.inference_mode():
+     generation = model.generate(**inputs, max_new_tokens=512, do_sample=False)
+     generation = generation[0][prefix_length:]
+     decoded = processor.decode(generation, skip_special_tokens=True)
+     print(json.dumps(json.loads(decoded), indent=4))
+ ```
+
+ This code loads the base PaliGemma model with the fine-tuned adapter, processes an input image, and prints the extracted document data as JSON.
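
The snippet passes `decoded` straight to `json.loads`, which raises an exception if the generated text is not valid JSON (for example, if generation stops at `max_new_tokens` before the closing brace). A minimal sketch of a more defensive parse follows; `parse_model_json` is a hypothetical helper, not part of the README or of any library, and the fallback to the outermost `{...}` span is an assumption about how the output might be malformed.

```python
import json


def parse_model_json(decoded: str):
    """Parse decoded model text as JSON; return None if no valid JSON is found.

    Hypothetical helper for illustration, not part of the original README.
    """
    text = decoded.strip()
    try:
        return json.loads(text)
    except json.JSONDecodeError:
        pass
    # Fall back to the outermost {...} span in case extra text surrounds the JSON.
    start, end = text.find("{"), text.rfind("}")
    if start != -1 and end > start:
        try:
            return json.loads(text[start:end + 1])
        except json.JSONDecodeError:
            return None
    return None
```

With this helper, `parse_model_json(decoded)` could stand in for `json.loads(decoded)`, and a `None` result would signal that the raw text needs to be inspected.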