fix: 修复滑块验证的入参校验

This commit is contained in:
兰涛 2024-08-14 16:53:24 +08:00
parent fd8318b559
commit ed139fbedc

View File

@ -1,3 +1,4 @@
from starlette.datastructures import UploadFile as StarletteUploadFile
import uvicorn import uvicorn
from fastapi import FastAPI, File, UploadFile, HTTPException, Form from fastapi import FastAPI, File, UploadFile, HTTPException, Form
from typing import Optional, Union from typing import Optional, Union
@ -7,8 +8,6 @@ from services import ocr_service
app = FastAPI() app = FastAPI()
from starlette.datastructures import UploadFile as StarletteUploadFile
async def decode_image(image: Union[UploadFile, StarletteUploadFile, str, None]) -> bytes: async def decode_image(image: Union[UploadFile, StarletteUploadFile, str, None]) -> bytes:
if image is None: if image is None:
@ -24,7 +23,8 @@ async def decode_image(image: Union[UploadFile, StarletteUploadFile, str, None])
image = image.split(',')[1] image = image.split(',')[1]
return base64.b64decode(image) return base64.b64decode(image)
except: except:
raise HTTPException(status_code=400, detail="Invalid base64 string") raise HTTPException(
status_code=400, detail="Invalid base64 string")
else: else:
raise HTTPException(status_code=400, detail="Invalid image input") raise HTTPException(status_code=400, detail="Invalid image input")
@ -42,7 +42,8 @@ async def ocr_endpoint(
return APIResponse(code=400, message="Either file or image must be provided") return APIResponse(code=400, message="Either file or image must be provided")
image_bytes = await decode_image(file or image) image_bytes = await decode_image(file or image)
result = ocr_service.ocr_classification(image_bytes, probability, charsets, png_fix) result = ocr_service.ocr_classification(
image_bytes, probability, charsets, png_fix)
return APIResponse(code=200, message="Success", data=result) return APIResponse(code=200, message="Success", data=result)
except Exception as e: except Exception as e:
return APIResponse(code=500, message=str(e)) return APIResponse(code=500, message=str(e))
@ -57,12 +58,12 @@ async def slide_match_endpoint(
simple_target: bool = Form(False) simple_target: bool = Form(False)
): ):
try: try:
if (background is None and target is None) or (background_file.size == 0 and target_file.size == 0): if (background is None and background_file and background_file.file._file is None) or (target is None and target_file and target_file.file._file is None):
return APIResponse(code=400, message="Both target and background must be provided") return APIResponse(code=400, message="Both target and background must be provided")
target_bytes = await decode_image(target_file or target) target_bytes = await decode_image(target_file or target)
background_bytes = await decode_image(background_file or background) background_bytes = await decode_image(background_file or background)
result = ocr_service.slide_match(target_bytes, background_bytes, simple_target) result = ocr_service.slide_match(
target_bytes, background_bytes, simple_target)
return APIResponse(code=200, message="Success", data=result) return APIResponse(code=200, message="Success", data=result)
except Exception as e: except Exception as e:
return APIResponse(code=500, message=str(e)) return APIResponse(code=500, message=str(e))