This commit is contained in:
兰涛 2024-08-14 16:53:53 +08:00 committed by GitHub
commit 0de4c4dfcb
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -1,3 +1,4 @@
from starlette.datastructures import UploadFile as StarletteUploadFile
import uvicorn import uvicorn
from fastapi import FastAPI, File, UploadFile, HTTPException, Form from fastapi import FastAPI, File, UploadFile, HTTPException, Form
from typing import Optional, Union from typing import Optional, Union
@ -7,8 +8,6 @@ from services import ocr_service
app = FastAPI() app = FastAPI()
from starlette.datastructures import UploadFile as StarletteUploadFile
async def decode_image(image: Union[UploadFile, StarletteUploadFile, str, None]) -> bytes: async def decode_image(image: Union[UploadFile, StarletteUploadFile, str, None]) -> bytes:
if image is None: if image is None:
@ -24,7 +23,8 @@ async def decode_image(image: Union[UploadFile, StarletteUploadFile, str, None])
image = image.split(',')[1] image = image.split(',')[1]
return base64.b64decode(image) return base64.b64decode(image)
except: except:
raise HTTPException(status_code=400, detail="Invalid base64 string") raise HTTPException(
status_code=400, detail="Invalid base64 string")
else: else:
raise HTTPException(status_code=400, detail="Invalid image input") raise HTTPException(status_code=400, detail="Invalid image input")
@ -38,11 +38,12 @@ async def ocr_endpoint(
png_fix: bool = Form(False) png_fix: bool = Form(False)
): ):
try: try:
if file.size == 0 and image is None: if file and file.file._file is None and image is None:
return APIResponse(code=400, message="Either file or image must be provided") return APIResponse(code=400, message="Either file or image must be provided")
image_bytes = await decode_image(file or image) image_bytes = await decode_image(file or image)
result = ocr_service.ocr_classification(image_bytes, probability, charsets, png_fix) result = ocr_service.ocr_classification(
image_bytes, probability, charsets, png_fix)
return APIResponse(code=200, message="Success", data=result) return APIResponse(code=200, message="Success", data=result)
except Exception as e: except Exception as e:
return APIResponse(code=500, message=str(e)) return APIResponse(code=500, message=str(e))
@ -57,12 +58,12 @@ async def slide_match_endpoint(
simple_target: bool = Form(False) simple_target: bool = Form(False)
): ):
try: try:
if (background is None and target is None) or (background_file.size == 0 and target_file.size == 0): if (background is None and background_file and background_file.file._file is None) or (target is None and target_file and target_file.file._file is None):
return APIResponse(code=400, message="Both target and background must be provided") return APIResponse(code=400, message="Both target and background must be provided")
target_bytes = await decode_image(target_file or target) target_bytes = await decode_image(target_file or target)
background_bytes = await decode_image(background_file or background) background_bytes = await decode_image(background_file or background)
result = ocr_service.slide_match(target_bytes, background_bytes, simple_target) result = ocr_service.slide_match(
target_bytes, background_bytes, simple_target)
return APIResponse(code=200, message="Success", data=result) return APIResponse(code=200, message="Success", data=result)
except Exception as e: except Exception as e:
return APIResponse(code=500, message=str(e)) return APIResponse(code=500, message=str(e))
@ -74,7 +75,7 @@ async def detection_endpoint(
image: Optional[str] = Form(None) image: Optional[str] = Form(None)
): ):
try: try:
if file.size == 0 and image is None: if file and file.file._file is None and image is None:
return APIResponse(code=400, message="Either file or image must be provided") return APIResponse(code=400, message="Either file or image must be provided")
image_bytes = await decode_image(file or image) image_bytes = await decode_image(file or image)