diff --git a/app/main.py b/app/main.py index 532a4bc..7e24788 100644 --- a/app/main.py +++ b/app/main.py @@ -1,3 +1,4 @@ +from starlette.datastructures import UploadFile as StarletteUploadFile import uvicorn from fastapi import FastAPI, File, UploadFile, HTTPException, Form from typing import Optional, Union @@ -7,8 +8,6 @@ from services import ocr_service app = FastAPI() -from starlette.datastructures import UploadFile as StarletteUploadFile - async def decode_image(image: Union[UploadFile, StarletteUploadFile, str, None]) -> bytes: if image is None: @@ -24,7 +23,8 @@ async def decode_image(image: Union[UploadFile, StarletteUploadFile, str, None]) image = image.split(',')[1] return base64.b64decode(image) except: - raise HTTPException(status_code=400, detail="Invalid base64 string") + raise HTTPException( + status_code=400, detail="Invalid base64 string") else: raise HTTPException(status_code=400, detail="Invalid image input") @@ -38,11 +38,12 @@ async def ocr_endpoint( png_fix: bool = Form(False) ): try: - if file.size == 0 and image is None: + if file and file.file._file is None and image is None: return APIResponse(code=400, message="Either file or image must be provided") image_bytes = await decode_image(file or image) - result = ocr_service.ocr_classification(image_bytes, probability, charsets, png_fix) + result = ocr_service.ocr_classification( + image_bytes, probability, charsets, png_fix) return APIResponse(code=200, message="Success", data=result) except Exception as e: return APIResponse(code=500, message=str(e)) @@ -57,12 +58,12 @@ async def slide_match_endpoint( simple_target: bool = Form(False) ): try: - if (background is None and target is None) or (background_file.size == 0 and target_file.size == 0): + if (background is None and background_file and background_file.file._file is None) or (target is None and target_file and target_file.file._file is None): return APIResponse(code=400, message="Both target and background must be provided") - target_bytes = await decode_image(target_file or target) background_bytes = await decode_image(background_file or background) - result = ocr_service.slide_match(target_bytes, background_bytes, simple_target) + result = ocr_service.slide_match( + target_bytes, background_bytes, simple_target) return APIResponse(code=200, message="Success", data=result) except Exception as e: return APIResponse(code=500, message=str(e)) @@ -74,7 +75,7 @@ async def detection_endpoint( image: Optional[str] = Form(None) ): try: - if file.size == 0 and image is None: + if file and file.file._file is None and image is None: return APIResponse(code=400, message="Either file or image must be provided") image_bytes = await decode_image(file or image)