AF3-squeeze_features函数解读
目录
AF3 squeeze_features函数解读
AlphaFold3
data_transforms 模块的
squeeze_features 函数的作用
去除
蛋白质特征张量中不必要的单维度(singleton dimensions)和重复维度
,以使其适配
AlphaFold3
预期的输入格式。
源代码:
def squeeze_features(protein):
"""Remove singleton and repeated dimensions in protein features."""
protein["aatype"] = torch.argmax(protein["aatype"], dim=-1)
for k in [
"domain_name",
"msa",
"num_alignments",
"seq_length",
"sequence",
"superfamily",
"deletion_matrix",
"resolution",
"between_segment_residues",
"residue_index",
"template_all_atom_mask",
]:
if k in protein:
final_dim = protein[k].shape[-1]
if isinstance(final_dim, int) and final_dim == 1:
if torch.is_tensor(protein[k]):
protein[k] = torch.squeeze(protein[k], dim=-1)
else:
protein[k] = np.squeeze(protein[k], axis=-1)
for k in ["seq_length", "num_alignments"]:
if k in protein:
protein[k] = protein[k][0]
return protein
源码解读:
- 该函数接收
protein
(一个 包含蛋白质特征的字典 )作为输入。 - 主要任务:
- 将 one-hot
aatype
转换为索引表示 。 - 移除 shape 为
(N, ..., 1)
的单维度 。 - 提取
seq_length
和num_alignments
的实际数值 。
- 将 one-hot
Step 1: 处理 aatype
protein["aatype"] = torch.argmax(protein["aatype"], dim=-1)
- 输入
aatype
(氨基酸类型)通常是 one-hot 编码 - 通过
torch.argmax(..., dim=-1)
获取 索引 - 目的
:简化
aatype
的数据表示,使其直接存储氨基酸索引,而不是 one-hot 矩阵。
Step 2: 移除单维度
for k in [
"domain_name",
"msa",
"num_alignments",
"seq_length",
"sequence",
"superfamily",
"deletion_matrix",
"resolution",
"between_segment_residues",
"residue_index",
"template_all_atom_mask",
]:
if k in protein:
final_dim = protein[k].shape[-1] # 获取最后一维的大小
if isinstance(final_dim, int) and final_dim == 1:
if torch.is_tensor(protein[k]):
protein[k] = torch.squeeze(protein[k], dim=-1) # 去掉单维度
else:
protein[k] = np.squeeze(protein[k], axis=-1)
- 遍历多个
protein
特征字段 ,检查它们是否存在。 - 如果最后一维
final_dim
为1
,说明这个维度是 无意义的单维度 ,需要去除:- 如果是
PyTorch 张量
(
torch.Tensor
),使用torch.squeeze(dim=-1)
。 - 如果是
NumPy 数组
,使用
np.squeeze(axis=-1)
。
- 如果是
PyTorch 张量
(
Step 3: 处理 seq_length
和 num_alignments
for k in ["seq_length", "num_alignments"]:
if k in protein:
protein[k] = protein[k][0]
seq_length
和
num_alignments
可能是
列表或张量
,但它们的数值其实是一个单独的整数,因此需要转换成
标量值
。
结论
- 1️⃣
- **转换
aatype
**- 从 one-hot 编码 转换成 索引表示 。
- 2️⃣
- 移除无用的单维度
- 让
msa
,resolution
,deletion_matrix
等数据符合 AlphaFold3 预期格式。 - 3️⃣
- **转换
seq_length
- 和
num_alignments
- 为标量**
- 确保它们不会以张量形式存在,而是整数。
💡 最终作用:保证输入数据的维度符合 AlphaFold3 训练时的输入要求,提高数据处理效率。