常用类

这里总结一些频繁用到的支持类。

from dataclasses import dataclassfrom ..utils import BaseOutputfrom collections import OrderedDictclass BaseOutput(OrderedDict):...@dataclassclass Unet2DOutput(BaseOutput):"""The output of [`Unet2DModel`].Args:sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):The hidden states output from the last layer of the model."""sample: torch.FloatTensor

BaseOutput继承自OrderedDict,可以记住数据插入的顺序。BaseOutput这个类是所有模型输出的基类。models\unet_2d.py中就定义了Unet2DOutput做为该模型的输出类。且还用了dataclass装饰符,表明这个类只承载数据输出的作用。

from .modeling_utils import ModelMixinfrom ..configuration_utils import ConfigMixin, register_to_configclass Unet2DModel(ModelMixin, ConfigMixin):"""A 2D UNet model that takes a noisy sample and a timestep and returns a sample shaped otuput."""...

unet

Unet2DModel

主体由down_blocks, mid_blocks, up_blocks三块组成。输入除了sample,还有time_embedding和label_embedding。

down_blocks
mid_blocks
up_blocks
embeddings
forward