Deep Learning with PYTHON
Custom Model TorchServing: A Success Story
euni_joa
2022. 7. 14. 15:42
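In the previous post I served a model with Flask, which gave me a simple introduction to APIs.
Why the title says "success story": last time I tried TorchServe, Docker, and so on, and failed :(
So I'm writing down what finally worked, so that others don't have to struggle the same way!!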
https://everyday-deeplearning.tistory.com/entry/Pytorch-serving-with-Flask
https://github.com/jeremiahschung/ghactions
Custom PyTorch Model Serving!
1. Train the model (defined as a subclass of nn.Module)
model.py
## >> model.py
## custom model
import torch
import torch.nn as nn


class ClassificationModel(nn.Module):
    # euni: every __init__ parameter needs a default value,
    # because TorchServe instantiates this class with no arguments
    def __init__(self, class_num=5, input_shape=(3, 224, 224), dim=128, rate=0.1):
        super(ClassificationModel, self).__init__()
        ## euni: 3x3 kernel + padding (1, 1) + stride 1 keeps H x W unchanged (like padding='same')
        self.conv2d = nn.Conv2d(3, 2, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        self.ReLU = nn.ReLU()
        self.flatten = nn.Flatten()
        # 2 conv output channels x H x W
        self.linear1 = nn.Linear(input_shape[1] * input_shape[2] * 2, dim)
        self.dropout = nn.Dropout(p=rate, inplace=False)
        self.linear2 = nn.Linear(dim, class_num)

    def forward(self, inputs):
        # No `training` flag needed: Dropout follows model.train() / model.eval()
        embedding = self.conv2d(inputs)
        embedding = self.ReLU(embedding)
        embedding = self.flatten(embedding)
        embedding = self.linear1(embedding)
        embedding = self.ReLU(embedding)
        embedding = self.dropout(embedding)
        embedding = self.linear2(embedding)
        return embedding


# class_num = 5
# input_shape = (3, 224, 224)
# classification_model = ClassificationModel(class_num=class_num, input_shape=input_shape, rate=0.2)
## assume the model has been trained
# classification_model
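The input size of `linear1` is hard-coded as `input_shape[1] * input_shape[2] * 2`, which only holds because the 3x3 conv with padding 1 and stride 1 keeps the 224x224 spatial size. A small stdlib-only sketch (the helper names are my own, not part of PyTorch) applies the output-size formula from the `nn.Conv2d` documentation to make that arithmetic explicit:

```python
def conv2d_out(size, kernel, stride=1, padding=0, dilation=1):
    """Spatial output size of nn.Conv2d along one axis (formula from the PyTorch docs)."""
    return (size + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1


def flattened_features(input_shape=(3, 224, 224), out_channels=2, kernel=3, stride=1, padding=1):
    """Hypothetical helper: number of features nn.Flatten yields after the conv layer."""
    _, height, width = input_shape
    out_h = conv2d_out(height, kernel, stride, padding)
    out_w = conv2d_out(width, kernel, stride, padding)
    return out_channels * out_h * out_w


print(flattened_features())           # 2 * 224 * 224 = 100352
print(flattened_features(padding=0))  # 2 * 222 * 222 = 98568
```

The same number has to hold at serving time, so the handler's resize must produce 224x224 images; any other size would give `linear1` the wrong number of features, which is one plausible cause of the 503 error covered below.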
2. state_dict
# torch.save(model.state_dict(), "custom_model.pt")
## save after training
torch.save(classification_model.state_dict(), 'classification_model.pt')
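Why the `__init__` defaults matter: when the archive is built from `--model-file` plus a state_dict `.pt`, TorchServe's base handler creates the model by calling the class with no arguments and then calls `load_state_dict` on it. A stdlib-only pre-flight check (my own helper, not a TorchServe API) can flag constructor parameters without defaults before you build the .mar:

```python
import inspect


def params_missing_defaults(cls):
    """Names of cls.__init__ parameters (excluding self, *args, **kwargs) that have no default."""
    params = list(inspect.signature(cls.__init__).parameters.values())[1:]  # drop self
    return [
        p.name
        for p in params
        if p.kind not in (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD)
        and p.default is inspect.Parameter.empty
    ]


# Stand-in classes for illustration; in practice pass ClassificationModel from model.py
class WithDefaults:
    def __init__(self, class_num=5, input_shape=(3, 224, 224), dim=128, rate=0.1):
        self.class_num = class_num


class WithoutDefaults:
    def __init__(self, class_num, rate=0.1):
        self.class_num = class_num


print(params_missing_defaults(WithDefaults))     # []
print(params_missing_defaults(WithoutDefaults))  # ['class_num']
```

The defaults also have to describe the architecture you actually trained (same `class_num`, `input_shape`, `dim`); otherwise `load_state_dict` will complain about size mismatches.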
3. custom handler (input image preprocessing & model inference postprocess)
https://github.com/pytorch/serve/tree/master/examples/image_classifier/mnist
handler.py
# >> handler.py
from torchvision import transforms
from ts.torch_handler.image_classifier import ImageClassifier
from torch.profiler import ProfilerActivity


class BinaryClassifier(ImageClassifier):
    """
    Custom image-classification handler. It extends ImageClassifier from image_classifier.py,
    a default TorchServe handler. The image transform is replaced with a plain resize to the
    model's 224x224 input, and postprocess() is overridden to return the predicted class index;
    everything else is reused from the parent class.
    """

    # euni: preprocessing (should match the training preprocessing and the model's input_shape)
    image_processing = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
    ])

    def __init__(self):
        super(BinaryClassifier, self).__init__()
        self.profiler_args = {
            "activities": [ProfilerActivity.CPU],
            "record_shapes": True,
        }

    def postprocess(self, data):
        """Convert the model output (logits) into predicted class indices.
        Args:
            data (torch.Tensor): model output of shape (batch_size, class_num)
        Returns:
            list: one predicted class index (int) per request in the batch
        """
        return data.argmax(1).tolist()
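What `postprocess` returns matters: TorchServe batches requests, and the handler must return a list with exactly one entry per request. `data.argmax(1)` picks, for each row of logits, the index of the largest value. A pure-Python sketch of the same logic (hypothetical helper, for illustration only):

```python
def postprocess_sketch(logits):
    """Mimic data.argmax(1).tolist(): one predicted class index per row (= per request)."""
    return [max(range(len(row)), key=row.__getitem__) for row in logits]


batch = [
    [0.1, 2.0, -1.0, 0.3, 0.0],  # request 1: largest logit at index 1
    [3.0, 0.5, 0.2, 0.1, 0.0],   # request 2: largest logit at index 0
]
print(postprocess_sketch(batch))  # [1, 0]
```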
4. torch-model-archiver -> .mar file
$ git clone https://github.com/pytorch/serve.git
$ cd serve # must be run from inside the cloned repo
$ python ./ts_scripts/install_dependencies.py # installs the requirements (CPU version)
$ cd ..
$ pip install torchserve torch-model-archiver torch-workflow-archiver
$ mkdir model_store # the .mar file is written here
$ torch-model-archiver --model-name binary_classification --version 1.0 --model-file model.py --serialized-file classification_model.pt --export-path model_store --handler handler.py
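Before starting the server it can help to confirm that model.py, handler.py and the weights actually ended up in the archive. Assuming the default archive format, which is a zip file with a manifest at `MAR-INF/MANIFEST.json`, the standard-library `zipfile` module is enough (the helper name is my own):

```python
import json
import zipfile


def inspect_mar(path):
    """Return (sorted file names, parsed manifest) for a zip-format .mar archive."""
    with zipfile.ZipFile(path) as archive:
        names = sorted(archive.namelist())
        manifest = json.loads(archive.read("MAR-INF/MANIFEST.json"))
    return names, manifest


# Usage:
# names, manifest = inspect_mar("model_store/binary_classification.mar")
# print(names)
# print(json.dumps(manifest, indent=2))
```

For the command above, the file list should include handler.py, model.py and classification_model.pt next to the manifest.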
5. torchserve --start ...
$ torchserve --start --model-store model_store --models binary_classification.mar
6. inference test
$ curl http://127.0.0.1:8080/predictions/binary_classification -T tmp.jpg
# when finished: $ torchserve --stop
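The same request can be sent from Python, e.g. to test a whole folder of images. This stdlib-only sketch mirrors the `curl -T` call above (curl's `-T` uploads the file as the body of a PUT request); host, port and model name are the ones used in this post, and the helper names are my own:

```python
import urllib.request


def build_prediction_request(image_bytes, model_name="binary_classification",
                             host="127.0.0.1", port=8080):
    """Build a PUT request carrying the raw image bytes, like `curl ... -T tmp.jpg`."""
    url = f"http://{host}:{port}/predictions/{model_name}"
    # Set Content-Type explicitly: urllib would otherwise default to
    # application/x-www-form-urlencoded for requests with a body.
    return urllib.request.Request(
        url,
        data=image_bytes,
        method="PUT",
        headers={"Content-Type": "application/octet-stream"},
    )


def predict(image_path, **kwargs):
    """Send one image to a running TorchServe instance and return the response text."""
    with open(image_path, "rb") as f:
        request = build_prediction_request(f.read(), **kwargs)
    with urllib.request.urlopen(request, timeout=30) as response:
        return response.read().decode("utf-8")


# Usage (requires the server from step 5 to be running):
# print(predict("tmp.jpg"))
```

A failed prediction such as the 503 below surfaces here as `urllib.error.HTTPError`.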
[issue 503]
{
"code": 503,
"type": "InternalServerException",
"message": "Prediction failed"
}
→ solution
1. Check the model and handler code for bugs! (the worker's Python traceback is written to logs/model_log.log)
2. The same port may already be in use → change the port
https://github.com/pytorch/serve/blob/master/docs/configuration.md#other-properties
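$ grep 8080 /etc/services # only shows the service registered for 8080; `lsof -i :8080` shows whether a process is actually listening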
- Create/update config.properties file
enable_envvars_config=true
inference_address=http://127.0.0.1:8443
management_address=http://127.0.0.1:8444
metrics_address=http://127.0.0.1:8445
- restart TorchServe with the new config
$ torchserve --start --model-store model_store --models binary_classification=binary_classification.mar --ts-config config.properties
- inference test
$ curl http://127.0.0.1:8443/predictions/binary_classification -T tmp.jpg
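- Note: if serving fails because of a config snapshot issue, delete the logs folder (TorchServe keeps its config snapshots under logs/config) and try again 😢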
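For the port-conflict case (solution 2), a Python way to check whether something is already listening on a port is to try connecting to it. A small stdlib sketch (my own helper, not part of TorchServe):

```python
import socket


def port_in_use(port, host="127.0.0.1", timeout=1.0):
    """Return True if something accepts TCP connections on host:port."""
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
        sock.settimeout(timeout)
        # connect_ex returns 0 on success instead of raising
        return sock.connect_ex((host, port)) == 0


# Usage: check TorchServe's default inference/management/metrics ports before starting
# for port in (8080, 8081, 8082):
#     print(port, port_in_use(port))
```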