반응형

참고: https://tutorials.pytorch.kr/intermediate/flask_rest_api_tutorial.html

장점

  1. 데이터의 전, 후 처리를 할 수 있다.
  2. 쉽게 API를 사용할 수 있어 범용성이 좋다.

 

Flask

아주 가벼운 웹프레임워크로 비교적 쉽게 배워서 사용할 수 있다.

$ pip install flask

 

1. 간단하게 웹서버 구성하기

from flask import Flask
# app = Flask("test")  # 설치 test용
app = Flask(__name__)

@app.route("/ping", methods=['GET'])
def ping():
    return "pong"

if __name__ == '__main__':
    app.run()

 

2. API 실행하기

$ Flask_APP=app.py FLASK_DEBUG=1 flask run

* 참고: window에서 실행 $ python app.py

 

3. 응답 확인하기

1) 웹 브라우저로 확인하기 : http://localhost:5000/ping 에 접속하면 pong 이 표시됨

2) httpie로 확인하기 : http -v GET http://localhost:5000/ping

3) python으로 확인하기

import requests

resp = requests.get("http://localhost:5000/ping")

 


Pytorch Rest API 배포하기

 

1. pytorch 모델 API 서버에 통합하기

from flask import Flask, jsonify, request
from PIL import Image
import torch
import torchvision.transforms as imtransforms


app = Flask(__name__)

# custom model 선언 & weight load
PATH = '{Mymodel_pretrained_weight}.pt'
model = Mymodel()
model.load_state_dict(torch.load(PATH, map_location=device), strict=False)
model.eval()


def transform_image(image_bytes):
    image = Image.open(io.BytesIO(image_bytes))  # byte file open
    image = imtransforms.Resize((imsize, imsize))(image)
    image = imtransforms.ToTensor()(image)
    return image.unsqueeze(0).to(device, torch.float)

def get_prediction(image_bytes):
    tensor = transform_image(image_bytes)
    outputs = model(tensor)  # predict
    ...
    return class_id, class_name  # 자유롭게 return 가능


@app.route('/predict', methods=['POST'])
def predict():
    file = request.files['file']
    image_bytes = file.read()
    class_id, class_name = get_prediction(image_bytes)
    return jsonify({'class_id':class_id, 'clss_name': class_name})

 

2. API 실행하기

$ FLASK_ENV=development FLASK_APP=app.py flask run

flask run

 

3. predict 하기

import requests
import json

resp = requests.post("http://localhost:5000/predict", 
		       files={"file": open('{file_name}.jpg', 'rb')})

# model predict 결과 확인하기
# 200: success
print(json.loads(resp.content))
반응형

+ Recent posts