折耳猫为什么不能养| 甲亢是什么意思| 什么时间泡脚最好| 耳仓为什么是臭的| 补体c3偏低是什么意思| 鱼鳔是什么东西| 吃什么可以补阳气| 植物纤维是什么面料| 芒种是什么意思| 红米是什么米| 上头了是什么意思| 溥仪什么时候去世的| 痔疮什么感觉| 吸狗是什么意思| 开什么店好赚钱| 群体是什么意思| 喝老陈醋有什么好处| 大熊猫吃什么| 强调是什么意思| 京ag6是什么意思| 鼓包是什么意思| 什么在千里| 发蜡和发泥有什么区别| 全自动洗衣机不排水是什么原因| 礽是什么意思| 安宫牛黄丸为什么那么贵| pioneer是什么牌子| 消融是什么意思| 79属什么生肖| 膑是什么意思| 牙周炎是什么症状| 五级职员是什么级别| 日加立念什么| 判决书什么时候生效| 怀孕小肚子疼是什么原因| 第二学士学位是什么意思| 糖霜是什么| 什么是热感冒| 检察院是做什么的| 恶露是什么| 尿后余沥是什么意思| 脚指甲变白是什么原因| dan是什么意思| 阴道恶臭是什么原因| bcl是什么意思| 国务院秘书长什么级别| 1点到3点是什么时辰| 可转债是什么| 什么是低保| 属鸡的守护神是什么菩萨| 县团级是什么级别| 纵隔淋巴结转移是什么意思| 痱子粉什么牌子好| 常喝普洱茶有什么好处| 23333是什么意思| 长河落日圆什么意思| 怀孕前三个月要注意什么| 潴是什么意思| 前列腺增大伴钙化是什么意思| 检查乳腺挂什么科| 苔藓是什么意思| 甜五行属什么| 丝状疣是什么原因长出来的| 下边瘙痒是什么原因| 卡介苗是预防什么| 嘴巴里甜甜的是什么原因| 骨折补钙吃什么钙片好| 为什么会发生地震| 转化是什么意思| 为什么胃疼| classy是什么意思| 怀孕前三个月应该注意什么| 张飞穿针歇后语下一句是什么| 吃什么补白蛋白最快最好| 贝贝什么意思| 生姜能治什么病| 为什么会下冰雹| 心慌吃什么药效果好| 省政协委员是什么级别| 充电玩手机有什么危害| 小意思是什么意思| 白蛋白偏高是什么原因| 睡不着觉去医院挂什么科| 推举是什么意思| 白内障有什么症状| 胆囊萎缩是什么原因| 龙和什么生肖最配| 焦糖色裤子配什么颜色上衣| 什么是穿刺| 刘禹锡是什么朝代的| 大脑供血不足吃什么药最好| 什么是普世价值| 儿童肚子疼挂什么科| 哗众取宠是什么意思| 1551是什么意思| 左侧小腹疼是什么原因| 化疗期间吃什么最好| 老板是什么意思| 断子绝孙是什么意思| rads是什么意思| 脉冲什么意思| 脚后跟麻木是什么原因| a型血和ab型血生的孩子是什么血型| 检查痛风挂什么科| 高净值什么意思| 乳头瘤是什么病| 京东自营是什么意思| 北方五行属什么| 唇炎看什么科室| 营救是什么意思| 什么动物跑得快| 养肝护肝吃什么好| 宝宝拉肚子有粘液是什么原因| 无名指长痣代表什么| 什么是眩晕症| 鸡尖是什么| innisfree是什么牌子的化妆品| AMY医学上是什么意思| 熊猫是什么科| 但愿是什么意思| 浑什么意思| 痢疾吃什么药效果最好| 浮沉是什么意思| 尿频是什么意思| 梦见别人给我介绍对象是什么意思| 一什么狼| 喉癌是什么原因引起的| 吃三七粉有什么作用| 同样的药为什么价格相差很多| b2b是什么| 查肝功能能查出什么病| 喜形于色是什么意思| 81年的鸡是什么命| 有张有弛是什么意思| 过期的洗面奶可以用来做什么| 喝完酒吃什么解酒最快| 苏格兰牧羊犬吃什么| 74年属什么的生肖| 感冒为什么不能吃鸡蛋| 喝咖啡对身体有什么好处| 人生观价值观世界观是什么意思| 甲状腺结节有什么症状表现| 癌抗原125是什么意思| 黄瓜炒什么菜好吃| 芙字五行属什么| 妇科支原体感染吃什么药| 尿很黄是什么原因| 毛囊是什么样子图片| 吃什么利尿最快| 什么水果治便秘| 周星驰是什么星座| 普拉提是什么| 同工同酬什么意思| 心无什么用| 什么的金边| 困惑是什么意思| 是什么拼音| 自给自足是什么意思| 番茄不能和什么一起吃| 什么的松脂| 脑供血不足做什么检查| 阴道有异味买什么药| 补血吃什么食物最好| 银屑病为什么会自愈| 孕妇血糖高可以吃什么水果| 过三关 是什么意思| 大姐大是什么意思| 推迟月经吃什么药| 三元及第是什么意思| 先天性心脏病是什么原因造成的| 祭坛是什么意思| ect是什么| 小儿惊痫是什么症状| 惊鸿是什么意思| 荔枝补什么| 尿隐血弱阳性是什么意思| 红枣什么时候吃最好| 介错是什么意思| 小乌龟吃什么| 山竹什么时候吃是应季| 迁单是什么意思| 贫血缺什么元素| 梦见牛是什么意思| 检查肝脏应该挂什么科| 缠腰蛇是什么症状图片| 尿道感染挂什么科| mixblu是什么牌子| 夏五行属什么| 男人脚肿是什么原因| sunny是什么意思| 龙的九个儿子都叫什么名字| 1月24日是什么星座| 养老院和敬老院有什么区别| 唐筛21三体临界风险是什么意思| 毛发旺盛女生什么原因引起的| 2021什么年| 属鼠和什么属相相冲| 菊花什么时候开花| 感冒吃什么药最快| 手指倒刺是什么原因| 白咖啡是什么| 烧心是什么感觉| 什么的拳头| 举头三尺有神明是什么意思| 梦到抓鱼是什么意思| 煮茶叶蛋用什么茶| 梦见狗是什么意思| 武则天原名叫什么| 什么病需要化疗| 蜂蜜和什么食物相克| 吃红苋菜有什么好处| 腋下有异味是什么原因| psp是什么| 精液是什么味道的| 窘迫是什么意思| 病人是什么生肖| 嘴唇是紫色的是什么原因| 温开水冲服是什么意思| 卫青为什么被灭九族| 中暑是什么症状表现| 肺部有问题一般会出现什么症状| 嘴苦是什么原因引起的| 茶油有什么功效| 贫血要注意些什么| 梦见被追杀是什么预兆| 安全期一般是什么时候| 反射弧太长是什么意思| 侄女结婚送什么礼物最好| 吃什么对肺有好处| 男人身体怕冷是什么原因如何调理| 用什么消肿最快| 蛋白粉适合什么人吃| 正餐是什么意思| 桂枝茯苓丸主治什么病| 烤油边是什么| 什么的爱| 2020是什么生肖| 午安是什么意思| 甲状腺是由什么引起的| 双子座上升星座是什么| 肝火旺吃什么中药| 凉拌菜用什么醋最好| 罹患是什么意思| 汉族人是什么人种| 脸过敏用什么药膏效果最好| 舟状腹见于什么疾病| 支气管挂什么科| 赵云的坐骑是什么马| 补牙为什么要分三次| 血液病是什么| 一味是什么意思| 什么是生育津贴| 农村补贴什么时候发放| 油烟机没有吸力是什么原因| 处长是什么级别| 卒中中心是什么意思| 淋巴结炎吃什么药| 刚怀孕要吃些什么好| 五七年属什么生肖| 卵泡破裂是什么意思| 茶卡是什么意思| 羊癫疯有什么症状表现| 早起的鸟儿有虫吃是什么意思| 一什么眼镜| 百度
Skip to content

add gelu and erf primitive operators for new autograd #45338

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Sep 1, 2022
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
add gelu orig2prim rule
  • Loading branch information
cxxly committed Aug 23, 2022
commit c962563d517ab8bde1635e341a6ea3fc4f4cf887
45 changes: 45 additions & 0 deletions python/paddle/fluid/tests/unittests/autograd/test_orig2prim.py
Original file line number Diff line number Diff line change
Expand Up @@ -579,5 +579,50 @@ def init_data(self):
self.out_map = {0: self.output['Out']}


class TestGeluOrig2Prim(TestElementWiseAddOrig2Prim):

def init_data(self):
self.op_type = 'gelu'
X = paddle.static.data(name='X', shape=[5, 8], dtype='float')

self.input = {'X': X}
self.output = {
'Out':
self.layer_help.create_variable_for_type_inference(dtype=X.dtype)
}
self.attrs = {'approximate': False}

self.orig2prim_args = (X, )
self.all_ops = [
'gelu', 'add_p', 'erf_p', 'fill_constant_p', 'fill_constant_p',
'fill_constant_p', 'mul_p', 'mul_p', 'mul_p'
]
# { prim_op_output_index: orig_op_output_var }
self.out_map = {0: self.output['Out']}


class TestGeluApproximateOrig2Prim(TestElementWiseAddOrig2Prim):

def init_data(self):
self.op_type = 'gelu'
X = paddle.static.data(name='X', shape=[5, 8], dtype='float')

self.input = {'X': X}
self.output = {
'Out':
self.layer_help.create_variable_for_type_inference(dtype=X.dtype)
}
self.attrs = {'approximate': True}

self.orig2prim_args = (X, )
self.all_ops = [
'add_p', 'add_p', 'fill_constant_p', 'fill_constant_p',
'fill_constant_p', 'fill_constant_p', 'fill_constant_p', 'gelu',
'mul_p', 'mul_p', 'mul_p', 'mul_p', 'pow_p', 'tanh_p'
]
# { prim_op_output_index: orig_op_output_var }
self.out_map = {0: self.output['Out']}


if __name__ == '__main__':
unittest.main()
41 changes: 29 additions & 12 deletions python/paddle/fluid/tests/unittests/autograd/test_primapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,6 +279,10 @@ def test_illegal_param(self):
np.array([2, 2, 2]),
), None, 'float32'),
('erf', paddle.erf, (np.random.rand(300, 288), ), None, 'float32'),
('gelu', paddle.nn.functional.gelu,
(np.random.rand(200, 189), ), None, 'float32'),
('gelu_approximate', lambda x: paddle.nn.functional.gelu(x, True),
(np.random.rand(200, 189), ), None, 'float32'),
))
class TestGrad(unittest.TestCase):

Expand Down Expand Up @@ -406,20 +410,33 @@ def multiply_pd(x):
erf_ag = lambda xs: ascipy.special.erf(xs[0])


def gelu_ag(x, approximate=False):
if approximate:
sqrt_2_over_pi = np.sqrt(2 / np.pi).astype(x.dtype)
cdf = 0.5 * (1.0 + anp.tanh(sqrt_2_over_pi * (x + 0.044715 * (x**3))))
return x * cdf
else:
return x * (ascipy.special.erf(x / np.sqrt(2)) + 1) / 2


@utils.place(config.DEVICES)
@utils.parameterize(
(utils.TEST_CASE_NAME, 'fun_pd', 'fun_ag', 'xs', 'v', 'dtype'), (
('multiply', multiply_pd, multiply_ag,
(np.random.rand(3, 5), ), None, 'float32'),
('sin', paddle.sin, sin_ag, (np.random.rand(2, 3), ), None, 'float32'),
('cos', paddle.cos, cos_ag, (np.random.rand(3, 4), ), None, 'float32'),
('exp', paddle.exp, exp_ag, (np.random.rand(2, 3), ), None, 'float32'),
('pow', paddle.pow, pow_ag,
(np.random.rand(2, 3), np.random.rand(2, 3)), None, 'float32'),
('log', paddle.log, log_ag, (np.random.rand(3, 8), ), None, 'float32'),
('erf', paddle.erf, erf_ag,
(np.random.rand(100, 200), ), None, 'float32'),
))
(utils.TEST_CASE_NAME, 'fun_pd', 'fun_ag', 'xs', 'v', 'dtype'),
(('multiply', multiply_pd, multiply_ag,
(np.random.rand(3, 5), ), None, 'float32'),
('sin', paddle.sin, sin_ag, (np.random.rand(2, 3), ), None, 'float32'),
('cos', paddle.cos, cos_ag, (np.random.rand(3, 4), ), None, 'float32'),
('exp', paddle.exp, exp_ag, (np.random.rand(2, 3), ), None, 'float32'),
('pow', paddle.pow, pow_ag,
(np.random.rand(2, 3), np.random.rand(2, 3)), None, 'float32'),
('log', paddle.log, log_ag, (np.random.rand(3, 8), ), None, 'float32'),
('erf', paddle.erf, erf_ag, (np.random.rand(100, 200), ), None, 'float32'),
('gelu', paddle.nn.functional.gelu, lambda xs: gelu_ag(xs[0]),
(np.random.rand(10, 20, 30), ), None, 'float32'),
('gelu_approximate',
lambda x: paddle.nn.functional.gelu(x, approximate=True),
lambda xs: gelu_ag(xs[0], approximate=True),
(np.random.rand(10, 20, 30), ), None, 'float32')))
class TestGradWithHigherOrder(unittest.TestCase):

def setUp(self):
Expand Down
25 changes: 23 additions & 2 deletions python/paddle/incubate/autograd/primrules.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,13 +327,34 @@ def elementwise_pow_orig2prim(op, x, y):
def elementwise_max_orig2prim(op, x, y):
if x.shape != y.shape:
y = broadcast(y, shape=x.shape)

return primops.max(x, y)


## Register prim2orig lower rules
@REGISTER_ORIG2PRIM('gelu')
def gelu_orig2prim(op, x):
if op.attr('approximate'):
cdf = mul(
fill_const(0.5, x.shape, x.dtype),
add(
fill_const(1.0, x.shape, x.dtype),
tanh(
mul(
fill_const(math.sqrt(2 / math.pi), x.shape, x.dtype),
add(
x,
mul(
fill_const(0.044715, x.shape, x.dtype),
primops.pow(x, fill_const(3., x.shape,
x.dtype))))))))
return mul(x, cdf)
else:
return mul(
mul(fill_const(0.5, x.shape, x.dtype), x),
add(fill_const(1.0, x.shape, x.dtype),
erf(mul(x, fill_const(1 / math.sqrt(2.), x.shape, x.dtype)))))


## Register prim2orig lower rules
@REGISTER_PRIM2ORIG('add_p')
def add_prim2orig(op, x, y):
return paddle.add(x, y)
Expand Down
老打瞌睡犯困是什么原因 尿道刺痛吃什么药 政客是什么意思 胸腔疼痛是什么原因 肝经湿热吃什么中成药
剖腹产第四天可以吃什么 耵聍是什么 知了喜欢吃什么 大小便失禁是什么意思 上不下大是什么字
带状疱疹后遗神经痛用什么药 cosplay是什么 相性是什么意思 衣食无忧是什么生肖 喝红花有什么作用与功效
梦到和死人说话是什么意思 尿素是什么 衣字旁的字和什么有关 翡翠跟玉有什么区别 脂肪瘤看什么科
黄字五行属什么hcv8jop5ns2r.cn 慢阻肺是什么原因引起的hcv9jop2ns3r.cn hpv16有什么症状hcv7jop5ns2r.cn 身上臭是什么原因hcv8jop3ns4r.cn 五行海中金是什么意思hcv9jop4ns4r.cn
喝酒拉肚子是什么原因hcv7jop6ns1r.cn 什么是肺大泡hcv9jop5ns3r.cn 在什么情况下需要做肠镜hcv9jop2ns1r.cn 什么事的英文hcv9jop5ns0r.cn 枕芯用什么填充物好hcv8jop8ns7r.cn
拍黄瓜是什么意思hcv8jop0ns6r.cn 肝裂不宽是什么意思luyiluode.com 性价比高什么意思hcv7jop5ns2r.cn 团购是什么意思hcv8jop8ns9r.cn 马来酸曲美布汀片什么时候吃hcv9jop3ns6r.cn
头晕应该挂什么科hcv8jop8ns8r.cn KT是什么xianpinbao.com rt是什么hcv9jop5ns2r.cn 总胆固醇高是什么原因hcv9jop2ns7r.cn 灵枢是什么意思hcv8jop0ns4r.cn
百度