Torch和Numpy——查看形状类型
输入
import numpy as np
import torch
a = np.array([[1,2],[3,4]])
print(a.shape,np.shape(a),a.dtype)
print('@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@')
a = a.astype(np.float32)
print(a.dtype)
print('***************************************************')
b = torch.tensor([[1,2],[3,4]])
print(b.shape,b.size(),b.type())
print("&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&")
b = b.float()
print(b.dtype)
输出
(2, 2) (2, 2) int32
@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@
float32
***************************************************
torch.Size([2, 2]) torch.Size([2, 2]) torch.LongTensor
&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&
torch.float32