tensor.permute

PyTorch 维度转换：我们读取图片时，得到的形状一般是 (H, W, C)（高、宽、通道），然而将图片输入网络时却常常需要 (C, H, W)，所以需要调整维度。

代码:

import cv2
import numpy as np
import torch

def reshape_by_cv2(image):
    """Convert an (H, W, C) image to (C, H, W) layout using ``cv2.split``.

    Every channel returned by ``cv2.split`` becomes one plane of the result,
    so the channel count is not fixed at three (e.g. BGRA images also work).

    Args:
        image: image array laid out as (H, W, C), such as the BGR array
            returned by ``cv2.imread``.

    Returns:
        np.ndarray of dtype uint8 with shape (C, H, W).
    """
    # cv2.split yields one (H, W) plane per channel. Stacking the whole
    # sequence (instead of unpacking into b, g, r) puts the channel axis
    # first for any number of channels.
    channels = cv2.split(image)
    return np.array(channels, dtype=np.uint8)

def reshape_by_torch(image):
    """Return *image* as a uint8 torch tensor with its last axis moved first.

    For an input laid out as (H, W, C) the result has shape (C, H, W).
    ``torch.tensor`` copies the data into a new tensor. ``permute`` then
    returns a view of that copy with reordered strides, which is generally
    not contiguous.
    """
    as_tensor = torch.tensor(image, dtype=torch.uint8)
    # New axis order: old axis 2 (channels), then old axes 0 and 1.
    return as_tensor.permute(2, 0, 1)

def show_channle(image, channel):
    """Show one channel plane of a channel-first image and wait for a key.

    Args:
        image: indexable image whose first axis selects the channel, such as
            the (C, H, W) arrays built by ``reshape_by_cv2`` or
            ``reshape_by_torch(...).numpy()`` in this script.
        channel: index of the channel to display; its string form is also
            used as the window title.
    """
    plane = image[channel]
    window_name = "{0}".format(channel)
    cv2.imshow(window_name, plane)
    # Block until a key is pressed, then release the window. Otherwise it
    # stays open and is silently reused by the next call with the same title.
    cv2.waitKey()
    cv2.destroyWindow(window_name)

if __name__ == '__main__':
    image_path = "A.jpg"
    image = cv2.imread(image_path)
    # cv2.imread does not raise on failure: it returns None for a missing or
    # undecodable file. Without this check, that would surface later as a
    # confusing error inside cv2.split.
    if image is None:
        raise FileNotFoundError("could not read image file: {0}".format(image_path))
    print(np.shape(image))  # layout as loaded by OpenCV: (H, W, C)
    image1 = reshape_by_cv2(image)
    image2 = reshape_by_torch(image)
    print(image1.shape)  # channel-first: (C, H, W)
    show_channle(image1, 0)
    print(image2.shape)  # torch.Size([C, H, W])
    show_channle(image2.numpy(), 0)



文章目录