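Test notes
: Full compatibility is not guaranteed. With a custom model, it's best to anticipate these issues when you build the model, and only then apply the conversion.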
TensorFlow to PyTorch
Environment setup
```shell
$ pip install tensorflow
$ conda install pytorch torchvision cpuonly -c pytorch
$ conda install -c conda-forge onnx
$ pip install tf2onnx
$ pip install onnx2pytorch
$ pip install onnxruntime   # needed for the ONNX inference test
```
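Because compatibility between these libraries isn't guaranteed, it helps to record which versions were actually installed when a conversion works (or fails). Here is a minimal sketch using only the standard library's `importlib.metadata` (Python 3.8+). The package list covers the libraries this post imports; the assumption that conda's `pytorch` package registers under the distribution name `torch` is noted in the code.

```python
from importlib import metadata

# Distribution names for the libraries used in this post. Assumption: the conda
# "pytorch" package registers itself under the distribution name "torch".
PACKAGES = ["tensorflow", "torch", "onnx", "tf2onnx", "onnx2pytorch", "onnxruntime"]


def installed_versions(packages=PACKAGES):
    """Return {package: version string, or None if it is not installed}."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


if __name__ == "__main__":
    for name, version in installed_versions().items():
        print(f"{name:>14}: {version or 'NOT INSTALLED'}")
```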
Creating the TF model
!!Caution!! Layers with no PyTorch counterpart, such as GlobalAveragePooling2D and GlobalMaxPooling2D, make the conversion fail with an error!
```python
import tensorflow as tf

class_num = 5

class ClassificationModel(tf.keras.Model):
    def __init__(self, class_num, dim=128, rate=0.1):
        super(ClassificationModel, self).__init__()
        self.conv2d = tf.keras.layers.Conv2D(2, 3, padding='same', activation='relu')
        self.dense1 = tf.keras.layers.Dense(dim, activation='relu')
        self.dense2 = tf.keras.layers.Dense(class_num, activation='softmax', name='output')
        self.flatten = tf.keras.layers.Flatten()
        self.dropout = tf.keras.layers.Dropout(rate)

    def call(self, inputs):
        embedding = self.conv2d(inputs)
        embedding = self.flatten(embedding)
        embedding = self.dense1(embedding)
        embedding = self.dropout(embedding)
        embedding = self.dense2(embedding)
        return embedding

input_shape = (224, 224, 3)
classification_model = ClassificationModel(class_num=class_num, rate=0.2)

## sample test
temp_input = tf.random.uniform(input_shape, dtype=tf.float32, minval=0, maxval=256)
output = classification_model(tf.expand_dims(temp_input, 0))
output.shape  # TensorShape([1, 5])

# check the model architecture & parameter count
classification_model.build((None, 224, 224, 3))
classification_model.summary()

classification_model.compile(optimizer='adam',
                             loss=tf.keras.losses.CategoricalCrossentropy(from_logits=False),
                             metrics=['categorical_accuracy'])

## !! in practice, the model should be trained here !!

# Writes a SavedModel directory (TF 2.x with Keras 2). With Keras 3, save_model no longer
# writes SavedModel directories; model.export(model_save_path) is the SavedModel route there.
model_save_path = 'image_classification_model'
tf.keras.models.save_model(classification_model, model_save_path, include_optimizer=False)
```
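The `summary()` output isn't reproduced in the post, so as a rough cross-check, here is the parameter count worked out by hand from the layer definitions. It assumes Keras defaults: stride 1 and biases enabled. The `flatten_size` value is the same 100352 that shows up later in the converted PyTorch model as `Reshape(shape=[ -1 100352])` and `Linear(in_features=100352, ...)`.

```python
# Hand-computed parameter counts for ClassificationModel with input (224, 224, 3),
# Conv2D(2 filters, 3x3, padding='same'), dim=128, class_num=5.
H, W, C_IN = 224, 224, 3
FILTERS, KERNEL = 2, 3
DIM, CLASS_NUM = 128, 5

conv2d_params = KERNEL * KERNEL * C_IN * FILTERS + FILTERS   # weights + biases
flatten_size = H * W * FILTERS                                # 'same' padding keeps 224x224
dense1_params = flatten_size * DIM + DIM                      # weights + biases
dense2_params = DIM * CLASS_NUM + CLASS_NUM                   # weights + biases
total_params = conv2d_params + dense1_params + dense2_params  # Flatten/Dropout add none

print(flatten_size)   # 100352
print(total_params)   # 12845885
```

If `summary()` reports a different total, some layer setting differs from the assumptions above.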
tf2onnx
```shell
$ python -m tf2onnx.convert --saved-model image_classification_model --output image_classification_model.onnx
```
ONNX inference Test
```python
import numpy as np
import onnxruntime as ort

img_path = 'tmp224.npy'
img = np.load(img_path)                 # saved beforehand to match the input shape (224, 224, 3)
img = (img / 255.0).astype('float32')   # input scaling
img = np.expand_dims(img, 0)            # add batch axis -> (1, 224, 224, 3)

sess_ort = ort.InferenceSession('image_classification_model.onnx',
                                providers=['CPUExecutionProvider'])
res = sess_ort.run(None, input_feed={sess_ort.get_inputs()[0].name: img})

## res
# [array([[0.24008103, 0.19883673, 0.1655813 , 0.20317516, 0.19232577]],
#        dtype=float32)]
```
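`tmp224.npy` is the author's own file and isn't provided. To reproduce the run without it, this minimal sketch writes a hypothetical stand-in (random `uint8` pixels with the model's input shape) and applies the same preprocessing as the ONNX test. Because the data is random, the probabilities you get will differ from the ones printed in this post.

```python
import numpy as np

# Hypothetical stand-in for the author's tmp224.npy: a random uint8 "image"
# with the model's input shape (224, 224, 3). Replace with a real image array.
rng = np.random.default_rng(0)
dummy = rng.integers(0, 256, size=(224, 224, 3), dtype=np.uint8)
np.save("tmp224.npy", dummy)


def preprocess(img: np.ndarray) -> np.ndarray:
    """Scale pixel values to [0, 1] and add a batch axis: (H, W, C) -> (1, H, W, C)."""
    img = (img / 255.0).astype(np.float32)  # the exported graph takes float32 input
    return np.expand_dims(img, 0)


batch = preprocess(np.load("tmp224.npy"))
print(batch.shape, batch.dtype)  # (1, 224, 224, 3) float32
```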
onnx2pytorch
```python
import onnx
from onnx2pytorch import ConvertModel

onnx_model = onnx.load('image_classification_model.onnx')
pytorch_model = ConvertModel(onnx_model)
```
```python
pytorch_model
# ConvertModel(
#   (Transpose_StatefulPartitionedCall/classification_model_2/conv2d_3/Conv2D__6:0): Transpose()
#   (Conv_StatefulPartitionedCall/classification_model_2/conv2d_3/Conv2D:0): Conv2d(3, 2, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
#   (Relu_StatefulPartitionedCall/classification_model_2/conv2d_3/Relu:0): ReLU(inplace=True)
#   (Transpose_StatefulPartitionedCall/classification_model_2/conv2d_3/Conv2D__8:0): Transpose()
#   (Reshape_StatefulPartitionedCall/classification_model_2/flatten_1/Reshape:0): Reshape(shape=[ -1 100352])
#   (MatMul_StatefulPartitionedCall/classification_model_2/dense_3/MatMul:0): Linear(in_features=100352, out_features=128, bias=False)
#   (Relu_StatefulPartitionedCall/classification_model_2/dense_3/Relu:0): ReLU(inplace=True)
#   (MatMul_StatefulPartitionedCall/classification_model_2/output/MatMul:0): Linear(in_features=128, out_features=5, bias=False)
#   (Softmax_output_1): Softmax(dim=-1)
# )
```
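The two `Transpose()` layers around the convolution in the printout come from a layout difference. Keras defaults to channels-last (NHWC), while ONNX `Conv` and PyTorch `Conv2d` expect channels-first (NCHW). My reading is that the second `Transpose` exists so the following `Reshape` flattens features in the same order the Dense weights were trained on. This small NumPy sketch uses a toy 4×4×2 tensor, not the real model, to illustrate why that matters.

```python
import numpy as np

# Small NHWC activation standing in for the Conv2D output (N, H, W, C).
x_nhwc = np.arange(1 * 4 * 4 * 2).reshape(1, 4, 4, 2)

# ONNX Conv (and PyTorch Conv2d) work in NCHW, so the converted graph
# transposes NHWC -> NCHW around the convolution ...
x_nchw = x_nhwc.transpose(0, 3, 1, 2)        # (1, 2, 4, 4)

# ... and back NCHW -> NHWC afterwards, so the flatten sees the Keras layout.
back_to_nhwc = x_nchw.transpose(0, 2, 3, 1)  # (1, 4, 4, 2)

keras_flat = x_nhwc.reshape(1, -1)           # what Keras' Flatten produces
naive_flat = x_nchw.reshape(1, -1)           # flattening NCHW directly
fixed_flat = back_to_nhwc.reshape(1, -1)

print(np.array_equal(fixed_flat, keras_flat))  # True
print(np.array_equal(naive_flat, keras_flat))  # False: features would be permuted
```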
```python
import torch

# Pickles the whole module, so onnx2pytorch must be importable when loading it back.
torch.save(pytorch_model, 'image_classification_model.pth')
```
PyTorch inference test
: check that the output matches the ONNX Runtime result
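```python
import torch

# weights_only=False is required on PyTorch >= 2.6 to unpickle a full model object
# (the argument exists since PyTorch 1.13).
pytorch_model = torch.load('image_classification_model.pth', weights_only=False)
pytorch_model.eval()
res = pytorch_model(torch.from_numpy(img))  # same (1, 224, 224, 3) input as the ONNX test

## res
# tensor([[0.2401, 0.1988, 0.1656, 0.2032, 0.1923]], grad_fn=<SoftmaxBackward0>)
```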
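Comparing printed numbers by eye works for five values. For anything larger, a tolerance-based check is more reliable. This minimal sketch uses NumPy's `allclose`, and the default `atol=1e-5` is an assumption: small floating-point differences between runtimes are normal. Below it is applied to the two outputs printed in this post. PyTorch displays values rounded to 4 decimals, so the check uses a looser tolerance there.

```python
import numpy as np


def outputs_match(reference, candidate, atol=1e-5):
    """True if two model outputs have the same shape and agree within an absolute tolerance."""
    reference = np.asarray(reference, dtype=np.float32)
    candidate = np.asarray(candidate, dtype=np.float32)
    return reference.shape == candidate.shape and bool(
        np.allclose(reference, candidate, rtol=0.0, atol=atol)
    )


# Values printed in this post: ONNX Runtime output, and the PyTorch output
# as displayed (rounded to 4 decimals, hence atol=1e-4 below).
onnx_out = [[0.24008103, 0.19883673, 0.1655813, 0.20317516, 0.19232577]]
torch_printed = [[0.2401, 0.1988, 0.1656, 0.2032, 0.1923]]

print(outputs_match(onnx_out, torch_printed, atol=1e-4))  # True
print(np.argmax(onnx_out) == np.argmax(torch_printed))     # True: same predicted class
```

With the live objects, the call would look like `outputs_match(res_onnx[0], res_torch.detach().numpy())`. Here `res_onnx` and `res_torch` are hypothetical names for the ONNX Runtime and PyTorch results, since the post reuses `res` for both.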