Torch和Numpy——查看形状类型

输入

import numpy as np
import torch

a = np.array([[1,2],[3,4]])
print(a.shape,np.shape(a),a.dtype)
print('@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@')

a = a.astype(np.float32)
print(a.dtype)
print('***************************************************')

b = torch.tensor([[1,2],[3,4]])
print(b.shape,b.size(),b.type())
print("&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&")

b = b.float()
print(b.dtype)

输出

(2, 2) (2, 2) int32
@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@
float32
***************************************************
torch.Size([2, 2]) torch.Size([2, 2]) torch.LongTensor
&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&
torch.float32
文章目录