PYTHON으로 딥러닝하기

Model Conversion between Tensorflow and Pytorch | From TF To Torch

euni_joa 2022. 7. 12. 18:14
반응형

Test 의견

: 완벽하게 호환되지 않을 수 있음. custom model이라면 변환 과정에서 생길 수 있는 이슈(지원되지 않는 layer 등)를 미리 고려해서 모델을 설계한 뒤 적용해야 할 듯

TensorFlow → PyTorch (TensorFlow to PyTorch)

  1. 환경설정

     $ pip install tensorflow
     $ conda install pytorch torchvision cpuonly -c pytorch
    
     $ conda install -c conda-forge onnx
     $ pip install tf2onnx
     $ pip install onnx2pytorch
     $ pip install onnxruntime
  2. TF Model 생성 !!주의!! GlobalAveragePooling2D, GlobalMaxPooling2D 등 pytorch에 대응되는 layer가 없는 경우 변환 과정에서 에러가 발생함!

     import tensorflow as tf
     class_num = 5
    
     class ClassificationModel(tf.keras.Model):
         """Toy image classifier: Conv2D -> Flatten -> Dense -> Dropout -> softmax Dense.

         Args:
             class_num: Number of units in the final softmax layer.
             dim: Number of units in the hidden dense layer.
             rate: Rate handed to the dropout layer.
         """

         def __init__(self, class_num, dim=128, rate=0.1):
             super().__init__()
             # Layers are created in the same order as before so that Keras
             # tracks and names them identically.
             self.conv2d = tf.keras.layers.Conv2D(
                 filters=2, kernel_size=3, padding='same', activation='relu'
             )

             self.dense1 = tf.keras.layers.Dense(units=dim, activation='relu')
             self.dense2 = tf.keras.layers.Dense(
                 units=class_num, activation='softmax', name='output'
             )

             self.flatten = tf.keras.layers.Flatten()
             self.dropout = tf.keras.layers.Dropout(rate=rate)

         def call(self, inputs):
             """Run `inputs` through every layer in order and return the softmax output."""
             pipeline = (
                 self.conv2d,
                 self.flatten,
                 self.dense1,
                 self.dropout,
                 self.dense2,
             )
             features = inputs
             for layer in pipeline:
                 features = layer(features)
             return features
    
     # Shape of a single input image, without the batch axis.
     input_shape = (224, 224, 3)
     classification_model = ClassificationModel(class_num=class_num, rate=0.2)
    
     ## Sample test: forward one random tensor; expand_dims adds the leading batch axis.
     temp_input = tf.random.uniform(input_shape, dtype=tf.float32, minval=0, maxval=256)
     output = classification_model(tf.expand_dims(temp_input, 0))
    
     output.shape # TensorShape([1, 5])
    
     # Inspect the model architecture and parameter counts.
     classification_model.build((None, 224, 224, 3))
     classification_model.summary()
    
     # The last layer already applies softmax, hence from_logits=False.
     classification_model.compile(optimizer='adam',
                              loss=tf.keras.losses.CategoricalCrossentropy(from_logits=False),
                              metrics=['categorical_accuracy'])
    
     ## !! In practice the model should be trained here !!
    
     # Save without optimizer state; this path is what step 3 passes to
     # tf2onnx as --saved-model.
     # NOTE(review): newer Keras releases may reject an extension-less path
     # here -- confirm against the installed TensorFlow/Keras version.
     model_save_path = 'image_classification_model'
     tf.keras.models.save_model(classification_model, model_save_path, include_optimizer=False)
  3. tf2onnx

     $ python -m tf2onnx.convert --saved-model image_classification_model --output image_classification_model.onnx
  4. ONNX inference Test

     import numpy as np
     import onnxruntime as ort
    
     img_path = 'tmp224.npy'
     img = np.load(img_path)  # array saved beforehand to match the model's input shape
     img = (img/255.0).astype('float32')  # scale the input by 1/255 and cast to float32
     img = np.expand_dims(img, 0)  # add the leading batch axis
    
     # Load the ONNX file written by the tf2onnx step.
     sess_ort = ort.InferenceSession('image_classification_model.onnx')
    
     # Feed the image under the graph's first input name; None as the first
     # argument asks onnxruntime for all outputs.
     res = sess_ort.run(None, input_feed={sess_ort.get_inputs()[0].name: img})
    
     ## res (output recorded by the author)
     # [array([[0.24008103, 0.19883673, 0.1655813 , 0.20317516, 0.19232577]],
     #       dtype=float32)]
  5. onnx2pytorch

     import onnx
     from onnx2pytorch import ConvertModel
    
     # Load the ONNX graph produced in step 3 and wrap it as a PyTorch module.
     onnx_model = onnx.load('image_classification_model.onnx')
     pytorch_model = ConvertModel(onnx_model)
     # Notebook-style display of the converted module; its repr is shown below.
     pytorch_model
    
     # ConvertModel(
     #   (Transpose_StatefulPartitionedCall/classification_model_2/conv2d_3/Conv2D__6:0): Transpose()
     #   (Conv_StatefulPartitionedCall/classification_model_2/conv2d_3/Conv2D:0): Conv2d(3, 2, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
     #   (Relu_StatefulPartitionedCall/classification_model_2/conv2d_3/Relu:0): ReLU(inplace=True)
     #   (Transpose_StatefulPartitionedCall/classification_model_2/conv2d_3/Conv2D__8:0): Transpose()
     #   (Reshape_StatefulPartitionedCall/classification_model_2/flatten_1/Reshape:0): Reshape(shape=[    -1 100352])
     #   (MatMul_StatefulPartitionedCall/classification_model_2/dense_3/MatMul:0): Linear(in_features=100352, out_features=128, bias=False)
     #   (Relu_StatefulPartitionedCall/classification_model_2/dense_3/Relu:0): ReLU(inplace=True)
     #   (MatMul_StatefulPartitionedCall/classification_model_2/output/MatMul:0): Linear(in_features=128, out_features=5, bias=False)
     #   (Softmax_output_1): Softmax(dim=-1)
     # )
     import torch
     # Saves the whole module object (not just a state_dict), so loading it
     # later needs onnx2pytorch to be importable.
     torch.save(pytorch_model, 'image_classification_model.pth')
  6. pytorch inference test

    : 4번의 ONNX inference 결과와 PyTorch 결과가 같은지 확인

     # The checkpoint is a whole pickled module, not a state_dict. Recent PyTorch
     # releases default torch.load to weights_only=True, which refuses such files,
     # so opt out explicitly (the keyword needs PyTorch >= 1.13). Only do this
     # for files you created yourself: unpickling can execute arbitrary code.
     pytorch_model = torch.load('image_classification_model.pth', weights_only=False)
     pytorch_model.eval()  # make any train/eval-dependent layers run in inference mode
    
     # `img` is the float32 batch of shape (1, 224, 224, 3) prepared for the ONNX
     # test; gradients are not needed for a pure inference check.
     with torch.no_grad():
         res = pytorch_model(torch.from_numpy(img))
    
     ## res (same values as the ONNX Runtime output, to 4 decimals)
     # tensor([[0.2401, 0.1988, 0.1656, 0.2032, 0.1923]])
반응형