旋转位置编码-2
旋转位置编码 (2)
旋转位置编码常见的torch函数
导入必要的库
import torch
定义一个简单的函数来演示torch.view_as_complex和torch.view_as_real
def demo_view_as_complex_real():
创建一个形状为 (2, 3, 4, 2) 的张量,最后一个维度表示复数的实部和虚部
x = torch.randn(2, 3, 4, 2)
使用torch.view_as_complex将最后两个维度转换为复数
x_complex = torch.view_as_complex(x) print(“Original tensor shape:”, x.shape) print(“After view_as_complex shape:”, x_complex.shape)
使用torch.view_as_real将复数张量转换回实部和虚部
x_real = torch.view_as_real(x_complex) print(“After view_as_real shape:”, x_real.shape)
检查转换是否可逆
print(“Are the original and reconstructed tensors equal?”, torch.allclose(x, x_real)) demo_view_as_complex_real() Original tensor shape: torch.Size([2, 3, 4, 2]) After view_as_complex shape: torch.Size([2, 3, 4]) After view_as_real shape: torch.Size([2, 3, 4, 2]) Are the original and reconstructed tensors equal? True
定义一个简单的函数来演示torch.reshape和torch.flatten
def demo_reshape_flatten():
创建一个形状为 (2, 3, 4) 的张量
x = torch.randn(2, 3, 4)
使用torch.reshape改变张量的形状
x_reshaped = x.reshape(2, -1) # 将最后两个维度展平 print(“Original tensor shape:”, x.shape) print(“After reshape shape:”, x_reshaped.shape)
使用torch.flatten展平张量
x_flattened = x.flatten(start_dim=1) # 从第1维度开始展平 print(“After flatten shape:”, x_flattened.shape) demo_reshape_flatten() Original tensor shape: torch.Size([2, 3, 4]) After reshape shape: torch.Size([2, 12]) After flatten shape: torch.Size([2, 12]) xq = torch.randint(0, 10, (3, 4, 6)) # 创建一个形状为 (3, 4, 6) 的张量 print(“原始形状:”, xq.shape)
执行操作
result = xq.float().reshape(*xq.shape[:-1], -1, 2) # 转换为 float 并重塑形状 print(“转换后的形状:”, result.shape) 原始形状: torch.Size([3, 4, 6]) 转换后的形状: torch.Size([3, 4, 3, 2])
定义一个简单的函数来演示torch.view
def demo_view():
创建一个形状为 (2, 3, 4) 的张量
x = torch.randn(2, 3, 4, 4)
使用torch.view改变张量的形状
x_viewed = x.view(2, -1) # 将最后两个维度展平 print(“Original tensor shape:”, x.shape) print(“After view shape:”, x_viewed.shape) demo_view() Original tensor shape: torch.Size([2, 3, 4, 4]) After view shape: torch.Size([2, 48])
定义一个简单的函数来演示torch.type_as
def demo_type_as():
创建两个不同类型的张量
x = torch.randn(2, 3, dtype=torch.float32) y = torch.randn(2, 3, dtype=torch.float64)
使用torch.type_as将x的类型转换为与y相同
x_converted = x.type_as(y) print(“Original x dtype:”, x.dtype) print(“After type_as dtype:”, x_converted.dtype) demo_type_as() Original x dtype: torch.float32 After type_as dtype: torch.float64 import torch def unite_shape(pos_cis, x): print(f"输入 x 的形状: {x.shape}") print(f"输入 pos_cis 的形状: {pos_cis.shape}") ndim = x.ndim # 获取输入张量 x 的维度数 print(f"x 的维度数 (ndim): {ndim}") assert 0 <= 1 < ndim # 确保 x 至少有两个维度 print(“断言 0 <= 1 < ndim 通过”) assert pos_cis.shape == (x.shape[1], x.shape[-1]) # 检查 pos_cis 的形状是否匹配 print(“断言 pos_cis.shape == (x.shape[1], x.shape[-1]) 通过”) shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] # 构建新的形状 print(f"构建的新形状 (shape): {shape}") result = pos_cis.view(*shape) # 将 pos_cis 重塑为新的形状 print(f"重塑后的 pos_cis 形状: {result.shape}") return result
示例数据
x = torch.randn(10, 20, 30) # 创建一个形状为 (10, 20, 30) 的张量 pos_cis = torch.randn(20, 30) # 创建一个形状为 (20, 30) 的张量
调用函数
output = unite_shape(pos_cis, x) 输入 x 的形状: torch.Size([10, 20, 30]) 输入 pos_cis 的形状: torch.Size([20, 30]) x 的维度数 (ndim): 3 断言 0 <= 1 < ndim 通过 断言 pos_cis.shape == (x.shape[1], x.shape[-1]) 通过 构建的新形状 (shape): [1, 20, 30] 重塑后的 pos_cis 形状: torch.Size([1, 20, 30])
代码解释
demo_view_as_complex_real
函数:
- 该函数展示了如何使用
torch.view_as_complex
将张量的最后两个维度转换为复数,并使用torch.view_as_real
将其转换回实部和虚部。 - 通过
torch.allclose
检查转换是否可逆。
demo_reshape_flatten
函数:
- 该函数展示了如何使用
torch.reshape
和torch.flatten
来改变张量的形状。 reshape
可以将张量重新调整为指定的形状,而flatten
则将张量展平为指定的维度。
xq
张量的操作:
- 该部分代码展示了如何将一个形状为
(3, 4, 6)
的张量转换为float
类型,并重塑为(3, 4, 3, 2)
的形状。
demo_view
函数:
- 该函数展示了如何使用
torch.view
来改变张量的形状。view
类似于reshape
,但它要求张量在内存中是连续的。
demo_type_as
函数:
- 该函数展示了如何使用
torch.type_as
将一个张量的数据类型转换为与另一个张量相同。
unite_shape
函数:
- 该函数展示了如何根据输入张量
x
的形状来重塑pos_cis
张量。 - 通过
assert
语句确保输入张量的形状符合预期,并使用view
方法重塑pos_cis
的形状。 这些代码片段展示了 PyTorch 中常用的张量操作,包括形状变换、数据类型转换以及复数张量的处理。