mitlist/be/app/core/gemini.py

349 lines
14 KiB
Python

import logging
from typing import List
import google.generativeai as genai
from google.api_core import exceptions as google_exceptions
from app.config import settings
from app.core.exceptions import (
OCRServiceUnavailableError,
OCRServiceConfigError,
OCRUnexpectedError,
OCRQuotaExceededError,
OCRProcessingError
)
logger = logging.getLogger(__name__)

# Module-level state shared with get_gemini_client():
#   gemini_flash_client          -> the configured model, or None when unavailable
#   gemini_initialization_error  -> human-readable reason the client is unavailable
gemini_flash_client = None
gemini_initialization_error = None

# Build the client once at import time. Any failure is recorded instead of
# propagated, so importing this module never raises.
try:
    if not settings.GEMINI_API_KEY:
        # No key configured: remember why and leave the client unset.
        gemini_initialization_error = "GEMINI_API_KEY not configured. Gemini client not initialized."
        logger.error(gemini_initialization_error)
    else:
        genai.configure(api_key=settings.GEMINI_API_KEY)
        gemini_flash_client = genai.GenerativeModel(
            model_name=settings.GEMINI_MODEL_NAME,
            generation_config=genai.types.GenerationConfig(**settings.GEMINI_GENERATION_CONFIG),
        )
        logger.info(f"Gemini AI client initialized successfully for model '{settings.GEMINI_MODEL_NAME}'.")
except Exception as e:
    # Make sure a partially built client is never exposed.
    gemini_flash_client = None
    gemini_initialization_error = f"Failed to initialize Gemini AI client: {e}"
    logger.exception(gemini_initialization_error)
def get_gemini_client():
    """
    Return the module-level Gemini client.

    Raises:
        OCRServiceConfigError: If an initialization error was recorded or the
            client was never created.
    """
    client_unavailable = bool(gemini_initialization_error) or gemini_flash_client is None
    if client_unavailable:
        raise OCRServiceConfigError()
    return gemini_flash_client
# Prompt text asking the model to transcribe a shopping-list image and answer
# with a single JSON object containing `list_title`, `items` and `summary`.
# NOTE(review): the extraction code visible in this module sends
# `settings.OCR_ITEM_EXTRACTION_PROMPT` to the model, not this constant, and it
# splits the reply into plain lines instead of parsing JSON — confirm which
# prompt is authoritative and whether the line-based parsing matches it.
OCR_ITEM_EXTRACTION_PROMPT = """
**ROLE & GOAL**
You are an expert AI assistant specializing in Optical Character Recognition (OCR) and structured data extraction. Your primary function is to act as a "Shopping List Digitizer."
Your goal is to meticulously analyze the provided image of a shopping list, which is likely handwritten, and convert it into a structured, machine-readable JSON format. You must be accurate, infer context where necessary, and handle the inherent ambiguities of handwriting and informal list-making.
**INPUT**
You will receive a single image (`[Image]`). This image contains a shopping list. It may be:
* Neatly written or very messy.
* On lined paper, a whiteboard, a napkin, or a dedicated notepad.
* Containing doodles, stains, or other visual noise.
* Using various formats (bullet points, numbered lists, columns, simple line breaks).
* could be in English or in German.
**CORE TASK: STEP-BY-STEP ANALYSIS**
Follow these steps precisely:
1. **Initial Image Analysis & OCR:**
* Perform an advanced OCR scan on the entire image to transcribe all visible text.
* Pay close attention to the spatial layout. Identify headings, columns, and line items. Note which text elements appear to be grouped together.
2. **Item Identification & Filtering:**
* Differentiate between actual list items and non-item elements.
* **INCLUDE:** Items intended for purchase.
* **EXCLUDE:** List titles (e.g., "GROCERIES," "Target List"), dates, doodles, unrelated notes, or stray marks. Capture the list title separately if one exists.
3. **Detailed Extraction for Each Item:**
For every single item you identify, extract the following attributes. If an attribute is not present, use `null`.
* `item_name` (string): The primary name of the product.
* **Standardize:** Normalize the name. (e.g., "B. Powder" -> "Baking Powder", "A. Juice" -> "Apple Juice").
* **Contextual Guessing:** If a word is poorly written, use the context of a shopping list to make an educated guess. (e.g., "Ciffee" is almost certainly "Coffee").
* `quantity` (number or string): The amount needed.
* If a number is present (e.g., "**2** milks"), extract the number `2`.
* If it's a word (e.g., "**a dozen** eggs"), extract the string `"a dozen"`.
* If no quantity is specified (e.g., "Bread"), infer a default quantity of `1`.
* `unit` (string): The unit of measurement or packaging.
* Examples: "kg", "lbs", "liters", "gallons", "box", "can", "bag", "bunch".
* Infer where possible (e.g., for "2 Milks," the unit could be inferred as "cartons" or "gallons" depending on regional context, but it's safer to leave it `null` if not explicitly stated).
* `notes` (string): Any additional descriptive text.
* Examples: "low-sodium," "organic," "brand name (Tide)," "for the cake," "get the ripe ones."
* `category` (string): Infer a logical category for the item.
* Use common grocery store categories: `Produce`, `Dairy & Eggs`, `Meat & Seafood`, `Pantry`, `Frozen`, `Bakery`, `Beverages`, `Household`, `Personal Care`.
* If the list itself has category headings (e.g., a "DAIRY" section), use those first.
* `original_text` (string): Provide the exact, unaltered text that your OCR transcribed for this entire line item. This is crucial for verification.
* `is_crossed_out` (boolean): Set to `true` if the item is struck through, crossed out, or clearly marked as completed. Otherwise, set to `false`.
**HANDLING AMBIGUITIES AND EDGE CASES**
* **Illegible Text:** If a line or word is completely unreadable, set `item_name` to `"UNREADABLE"` and place the garbled OCR attempt in the `original_text` field.
* **Abbreviations:** Expand common shopping list abbreviations (e.g., "OJ" -> "Orange Juice", "TP" -> "Toilet Paper", "AVOs" -> "Avocados", "G. Beef" -> "Ground Beef").
* **Implicit Items:** If a line is vague like "Snacks for kids," list it as is. Do not invent specific items.
* **Multi-item Lines:** If a line contains multiple items (e.g., "Onions, Garlic, Ginger"), split them into separate item objects.
**OUTPUT FORMAT**
Your final output MUST be a single JSON object with the following structure. Do not include any explanatory text before or after the JSON block.
```json
{
"list_title": "string or null",
"items": [
{
"item_name": "string",
"quantity": "number or string",
"unit": "string or null",
"category": "string",
"notes": "string or null",
"original_text": "string",
"is_crossed_out": "boolean"
}
],
"summary": {
"total_items": "integer",
"unread_items": "integer",
"crossed_out_items": "integer"
}
}
```
**EXAMPLE WALKTHROUGH**
* **IF THE IMAGE SHOWS:** A crumpled sticky note with the title "Stuff for tonight" and the items:
* `2x Chicken Breasts`
* `~~Baguette~~` (this item is crossed out)
* `Salad mix (bag)`
* `Tomatos` (misspelled)
* `Choc Ice Cream`
* **YOUR JSON OUTPUT SHOULD BE:**
```json
{
"list_title": "Stuff for tonight",
"items": [
{
"item_name": "Chicken Breasts",
"quantity": 2,
"unit": null,
"category": "Meat & Seafood",
"notes": null,
"original_text": "2x Chicken Breasts",
"is_crossed_out": false
},
{
"item_name": "Baguette",
"quantity": 1,
"unit": null,
"category": "Bakery",
"notes": null,
"original_text": "Baguette",
"is_crossed_out": true
},
{
"item_name": "Salad Mix",
"quantity": 1,
"unit": "bag",
"category": "Produce",
"notes": null,
"original_text": "Salad mix (bag)",
"is_crossed_out": false
},
{
"item_name": "Tomatoes",
"quantity": 1,
"unit": null,
"category": "Produce",
"notes": null,
"original_text": "Tomatos",
"is_crossed_out": false
},
{
"item_name": "Chocolate Ice Cream",
"quantity": 1,
"unit": null,
"category": "Frozen",
"notes": null,
"original_text": "Choc Ice Cream",
"is_crossed_out": false
}
],
"summary": {
"total_items": 5,
"unread_items": 0,
"crossed_out_items": 1
}
}
```
**FINAL INSTRUCTION**
If the image provided is not a shopping list or is completely blank/unintelligible, respond with a JSON object where the `items` array is empty and add a note in the `list_title` field, such as "Image does not appear to be a shopping list."
Now, analyze the provided image and generate the JSON output.
"""
async def extract_items_from_image_gemini(image_bytes: bytes, mime_type: str = "image/jpeg") -> List[str]:
    """
    Uses Gemini Flash to extract shopping list items from image bytes.

    The prompt is read from ``settings.OCR_ITEM_EXTRACTION_PROMPT`` and sent
    together with the image. The reply text is split into lines; every
    stripped line longer than one character is returned as an item.

    Args:
        image_bytes: The image content as bytes.
        mime_type: The MIME type of the image (e.g., "image/jpeg", "image/png", "image/webp").

    Returns:
        A list of extracted item strings.

    Raises:
        OCRServiceConfigError: If the Gemini client is not initialized.
        OCRQuotaExceededError: If API quota is exceeded.
        OCRServiceUnavailableError: For general API call errors.
        OCRProcessingError: If the response is empty because it was blocked by safety settings.
        OCRUnexpectedError: If the response is empty for another reason, or for any other unexpected error.
    """
    try:
        client = get_gemini_client()  # Raises OCRServiceConfigError if not initialized
        image_part = {
            "mime_type": mime_type,
            "data": image_bytes,
        }
        prompt_parts = [
            settings.OCR_ITEM_EXTRACTION_PROMPT,
            image_part,
        ]
        logger.info("Sending image to Gemini for item extraction...")
        response = await client.generate_content_async(prompt_parts)

        candidates = response.candidates
        if not candidates or not candidates[0].content.parts:
            logger.warning("Gemini response blocked or empty.", extra={"response": response})
            finish_reason = candidates[0].finish_reason if candidates else 'UNKNOWN'
            safety_ratings = candidates[0].safety_ratings if candidates else 'N/A'
            # The SDK reports finish_reason as an enum member, which never
            # compares equal to the plain string 'SAFETY'. Compare by member
            # name, falling back to the raw value for plain strings.
            if getattr(finish_reason, "name", finish_reason) == 'SAFETY':
                raise OCRProcessingError(f"Gemini response blocked due to safety settings. Ratings: {safety_ratings}")
            raise OCRUnexpectedError()

        raw_text = response.text
        logger.info("Received raw text from Gemini.")
        # Keep every non-trivial line (more than one character after stripping).
        items = [
            cleaned_line
            for cleaned_line in (line.strip() for line in raw_text.splitlines())
            if len(cleaned_line) > 1
        ]
        logger.info(f"Extracted {len(items)} potential items.")
        return items
    except google_exceptions.GoogleAPIError as e:
        logger.error(f"Gemini API Error: {e}", exc_info=True)
        # ResourceExhausted is the API's quota/rate-limit error; the message
        # check is kept for quota errors reported under other classes.
        if isinstance(e, google_exceptions.ResourceExhausted) or "quota" in str(e).lower():
            raise OCRQuotaExceededError() from e
        raise OCRServiceUnavailableError() from e
    except (OCRServiceConfigError, OCRQuotaExceededError, OCRServiceUnavailableError, OCRProcessingError, OCRUnexpectedError):
        # Domain errors raised above pass through unchanged.
        raise
    except Exception as e:
        logger.error(f"Unexpected error during Gemini item extraction: {e}", exc_info=True)
        raise OCRUnexpectedError() from e
class GeminiOCRService:
    """OCR service holding its own Gemini model, built from ``settings``."""

    def __init__(self):
        """
        Configure the Gemini SDK and create the model used by this service.

        Raises:
            OCRServiceConfigError: If configuring the SDK or building the model fails.
        """
        try:
            genai.configure(api_key=settings.GEMINI_API_KEY)
            self.model = genai.GenerativeModel(
                model_name=settings.GEMINI_MODEL_NAME,
                generation_config=genai.types.GenerationConfig(
                    **settings.GEMINI_GENERATION_CONFIG
                )
            )
        except Exception as e:
            logger.error(f"Failed to initialize Gemini client: {e}")
            raise OCRServiceConfigError() from e

    async def extract_items(self, image_data: bytes, mime_type: str = "image/jpeg") -> List[str]:
        """
        Extract shopping list items from an image using Gemini Vision.

        The prompt is read from ``settings.OCR_ITEM_EXTRACTION_PROMPT``. The
        reply text is split into lines; stripped lines longer than one
        character that do not start with "Example" are returned.

        Args:
            image_data: The image content as bytes.
            mime_type: The MIME type of the image (e.g., "image/jpeg", "image/png", "image/webp").

        Returns:
            A list of extracted item strings.

        Raises:
            OCRQuotaExceededError: If API quota is exceeded.
            OCRServiceUnavailableError: For general API call errors.
            OCRProcessingError: If the response was blocked by safety settings.
            OCRUnexpectedError: If the response is empty, or for any other unexpected error.
        """
        try:
            image_parts = [{"mime_type": mime_type, "data": image_data}]
            response = await self.model.generate_content_async(
                contents=[settings.OCR_ITEM_EXTRACTION_PROMPT, *image_parts]
            )

            # Inspect the finish reason BEFORE reading response.text: per the
            # SDK docs the `text` accessor raises when the candidate has no
            # parts, which would hide a safety block behind OCRUnexpectedError.
            candidates = getattr(response, 'candidates', None)
            if candidates and hasattr(candidates[0], 'finish_reason'):
                finish_reason = candidates[0].finish_reason
                # finish_reason is an enum member in the SDK; compare by name,
                # falling back to the raw value for plain strings.
                if getattr(finish_reason, "name", finish_reason) == 'SAFETY':
                    safety_ratings = getattr(candidates[0], 'safety_ratings', 'N/A')
                    raise OCRProcessingError(f"Gemini response blocked due to safety settings. Ratings: {safety_ratings}")

            text = response.text
            if not text:
                logger.warning("Gemini response is empty")
                raise OCRUnexpectedError()

            # Keep non-trivial lines; lines starting with "Example" are skipped.
            items = [
                cleaned_line
                for cleaned_line in (line.strip() for line in text.splitlines())
                if len(cleaned_line) > 1 and not cleaned_line.startswith("Example")
            ]
            logger.info(f"Extracted {len(items)} potential items.")
            return items
        except google_exceptions.GoogleAPIError as e:
            logger.error(f"Error during OCR extraction: {e}")
            if "quota" in str(e).lower():
                raise OCRQuotaExceededError() from e
            raise OCRServiceUnavailableError() from e
        except (OCRServiceConfigError, OCRQuotaExceededError, OCRServiceUnavailableError, OCRProcessingError, OCRUnexpectedError):
            # Domain errors raised above pass through unchanged.
            raise
        except Exception as e:
            logger.error(f"Unexpected error during Gemini item extraction: {e}", exc_info=True)
            raise OCRUnexpectedError() from e