Forward和call函数理解

Pytorch中的forward理解

我们在使用Pytorch的时候,模型训练时,不需要调用forward这个函数,只需要在实例化一个对象中传入对应的参数就可以自动调用forward函数

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20

class Module(nn.Module):
    def __init__(self):
        super().__init__()
        #......

    def forward(self,x):
        #......
        return x 

data = .......  # 输入数据

# 实例化一个对象
model = Module()

# 前向传播
model(data)
# 而不是使用下面的
# model.forward(data)

但是实际上model(data)是等价于model.forward(data),这是为什么?

forward函数

model(data)之所以等于model.forward(data),就是因为在类(class)中使用了__class__函数,对__class__ 的理解写在后面

1
2
3
4
5
6
7

class Student:
    def __call__(self):
        print("I can be called lick a function")

a = Student()
a()

输出结 I can be called lick a function

由上面的__class__函数可知,我们可以将forward函数放到__class__函数中进行调用:

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
class Student:
    def __call__(self,param):
        print("I can called link a function")
        print("传入参数的类型是:{} 值为: {}".format(type(param),param))

        res = self.forward(param)
        return res 

    def forward(self,input_):
        print("forward 函数被调用了")

        print("in forward,传入参数类型是:{} 值为: {}.".format(type(input_),input_))
        return input_

a = Student()

input_param = a('data')
print("对象a传入的参数是:",input_param)

输出结果:

I can called like a function
传入参数的类型是:<class 'str'>   值为: data
forward 函数被调用了
in  forward, 传入参数类型是:<class 'str'>  值为: data
对象a传入的参数是: data

为什么model(data)等价于model.forward(data),是因为__call__函数中调用了forward函数.


__class__函数

该方法的功能类似于在类中重载()运算符,使得类实例对象可以像调用普通函数那样,以"对象名()"的形式使用. 作用:为了将类的实例对象变为可调用对象.

1
2
3
4
5
6
7
class CLanguage:
    def __call__(self,name,add):
        print("调用__call__()方法",name,add)

clangs = CLanguage()
clangs("程序执行")

程序执行结果为: 调用__call__()方法 程序执行

通过在CLanguage类中实现__call__()方法,🙆clangs实例对象变为 可调用对象

Python中,凡是可以将()直接应用到自身并执行,都称为可调用对象.可调用对象包括自定义函数/Python内置函数以及本节所说的类实例对象.

对于可调用对象,实际上"名称()"可以理解为是名称.__call__()的简写.上面程序中定义的clangs实例对象为例,其最后一行可以改写为: clangs.__call__("程序执行")

程序的执行结果与之前完全相同.

__call__弥补hasattr()函数的短板: hasattr()函数的用法,该函数的功能是查找类的实例对象中是否包含指定名称的属性或者方法,但该函数有一个缺陷,即无法判定该指定的名称,是否是类属性还是类方法.

要解决这个问题,我们可以借助可调用对象的概念. 类实例对象包含的方法,其实也属于可调用对象,但类属性却不是.

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
class CLanguage:
    def __init__(self):
        self.name = "Jone"
        self.add = "Messi"
    def say(self):
        print("Learning Python ")

clangs = CLanguage()
if hasattr(clangs,"name"):
    print(hasattr(clangs.name,"__class__"))
print("*********")
if hasattr(clangs,"say"):
    print(hasattr(clangs.say,"__call__"))

程序执行结果:

False
**********
True

由于name是类属性,它没有以'call'为名的'call()'方法;而say是类方法,它是可调用对象,因此它有__call__()方法.

其他实例

  1. 函数本身可以被调用
1
2
3
def func():
    pass 
print(callable(func))

输出结果: True

  1. 类本身可以被调用,主要用作生成实例化对象:
1
2
3
4
class Fun:
    def __init(self):
        pass 
print(callable(Fun))

输出结果: True

  1. 类的实例化对象无法被调用:
1
2
3
4
5
6
class Fun:
    def __init__(self):
        pass

a = Fun()
print(callable(a))

输出结果: False

  1. 通过增加__call__()函数实例化对象变为可调用
1
2
3
4
5
6
7
8
9
class Fun:
    def __init__(self):
        pass 
    def __call__(self,*args,**kwargs):
        pass

a = Fun()
print(callable(a))

输出结果: True


updatedupdated2021-05-242021-05-24