tensor.permute
PyTorch 维度转换：用 OpenCV 读取图片时，得到的数组形状一般是 (H, W, C)（高、宽、通道数），而输入 PyTorch 网络时通常要求 (C, H, W)，所以需要调整维度的顺序。
代码：
import cv2
import numpy as np
import torch
def reshape_by_cv2(image):
    """Reorder an OpenCV image from (H, W, C) to (C, H, W) using ``cv2.split``.

    ``cv2.split`` returns one 2-D plane per channel; stacking those planes
    along a new leading axis yields the channel-first layout.

    Args:
        image: Image array as returned by ``cv2.imread``, shape (H, W, C),
            or (H, W) for a single-channel image.

    Returns:
        A new ``np.uint8`` array of shape (C, H, W); a 2-D input yields
        shape (1, H, W).
    """
    # Stack however many planes cv2.split returns instead of unpacking into
    # exactly (b, g, r), so 1- and 4-channel images work as well as 3-channel.
    planes = cv2.split(image)
    return np.array(planes, dtype=np.uint8)
def reshape_by_torch(image):
    """Reorder an image from (H, W, C) to (C, H, W) using ``Tensor.permute``.

    Args:
        image: Array-like image, shape (H, W, C) as returned by ``cv2.imread``,
            or (H, W) for a single-channel image.

    Returns:
        A contiguous ``torch.uint8`` tensor of shape (C, H, W); a 2-D input
        yields shape (1, H, W).
    """
    tensor = torch.tensor(image, dtype=torch.uint8)
    if tensor.dim() == 2:
        # Single-channel image: add an explicit channel axis rather than
        # letting permute((2, 0, 1)) fail on a rank-2 tensor.
        return tensor.unsqueeze(0)
    # permute only rewrites strides and returns a non-contiguous view;
    # .contiguous() lays the data out in CHW order so .view() works and
    # .numpy() / per-channel slices yield plain row-major planes.
    return tensor.permute((2, 0, 1)).contiguous()
def show_channle(image, channel):
    """Display one channel plane of a channel-first image and wait for a key.

    Blocks until a key is pressed, then closes the window it opened.

    Args:
        image: Channel-first array indexable as ``image[channel]`` (e.g. the
            (C, H, W) numpy arrays built in this script); the selected plane
            must be something ``cv2.imshow`` accepts.
        channel: Index of the channel to show; also used as the window title.
    """
    window_name = str(channel)
    cv2.imshow(window_name, image[channel])
    cv2.waitKey()
    # Without this the window stays on screen but no longer processes events
    # (often shown as "not responding") until the process exits.
    cv2.destroyWindow(window_name)
if __name__ == '__main__':
    # Demo: load an image, convert HWC -> CHW with OpenCV and with PyTorch,
    # then display channel 0 of each result.
    image_path = "A.jpg"
    image = cv2.imread(image_path)
    if image is None:
        # cv2.imread signals failure by returning None instead of raising.
        raise SystemExit("could not read image: {0}".format(image_path))
    print(image.shape)  # (H, W, C) as loaded by OpenCV

    image1 = reshape_by_cv2(image)
    image2 = reshape_by_torch(image)

    print(image1.shape)  # (C, H, W)
    show_channle(image1, 0)
    print(image2.shape)  # torch.Size([C, H, W])
    show_channle(image2.numpy(), 0)