rasa_core: nlg模块源码解读

最近在学习使用rasa构建聊天机器人,为了实现一个比较特别的功能,需要搞懂源码。rasa 的代码质量相当高,注释完整,函数定义包含 type hint 读起来非常舒服。
rasa_core.nlg模块包含5个py脚本:

  • __init__.py
  • callback.py
  • generator.py
  • interpolator.py
  • template.py

首先看 __init__.py

from rasa.core.nlg.generator import NaturalLanguageGenerator
from rasa.core.nlg.template import TemplatedNaturalLanguageGenerator
from rasa.core.nlg.callback import CallbackNaturalLanguageGenerator

可以看到,nlg模块主要有三个类,

  • NaturalLanguageGenerator(NLG)
  • TemplatedNaturalLanguageGenerator(TNLG)
  • CallbackNaturalLanguageGenerator(CNLG)

TNLGCNLG都继承自NLG,所以从NLG开始。

NaturalLanguageGenerator

NLG类包含两个成员函数:

  • generate
  • create
    generate是抽象函数,没有具体实现,create是静态函数。
generate:
async def generate(
    self,
    template_name: Text,
    tracker: "DialogueStateTracker",
    output_channel: Text,
    **kwargs: Any,
) -> Optional[Dict[Text, Any]]

异步抽象函数,用于对用户输入产生回复。

create
@staticmethod
def create(
    obj: Union["NaturalLanguageGenerator", EndpointConfig, None],
    domain: Optional[Domain],
) -> "NaturalLanguageGenerator":
    """Factory to create a generator."""

    if isinstance(obj, NaturalLanguageGenerator):
        return obj
    else:
        return _create_from_endpoint_config(obj, domain)

静态函数,用于产生一个NLG实例。建议的输入obj是NLG实例或者EndpointConfig对象,domain是Domain对象,如果obj是NLG实例,直接返回obj,否则根据EndpointConfig和Domain的配置,借助了_create_from_endpoint_config函数,实例化一个NLG。

_create_from_endpoint_config

接下来,我们来看_create_from_endpoint_config这个函数。

def _create_from_endpoint_config(
    endpoint_config: Optional[EndpointConfig] = None, domain: Optional[Domain] = None,
) -> "NaturalLanguageGenerator":
    """Given an endpoint configuration, create a proper NLG object."""

    domain = domain or Domain.empty()

    if endpoint_config is None:
        from rasa.core.nlg import (  # pytype: disable=pyi-error
            TemplatedNaturalLanguageGenerator,
        )

        # this is the default type if no endpoint config is set
        nlg = TemplatedNaturalLanguageGenerator(domain.templates)
    elif endpoint_config.type is None or endpoint_config.type.lower() == "callback":
        from rasa.core.nlg import (  # pytype: disable=pyi-error
            CallbackNaturalLanguageGenerator,
        )

        # this is the default type if no nlg type is set
        nlg = CallbackNaturalLanguageGenerator(endpoint_config=endpoint_config)
    elif endpoint_config.type.lower() == "template":
        from rasa.core.nlg import (  # pytype: disable=pyi-error
            TemplatedNaturalLanguageGenerator,
        )

        nlg = TemplatedNaturalLanguageGenerator(domain.templates)
    else:
        nlg = _load_from_module_string(endpoint_config, domain)

    logger.debug(f"Instantiated NLG to '{nlg.__class__.__name__}'.")
    return nlg

_create_from_endpoint_config的输入同样是EndpointConfig对象和Domain对象。函数主体是if-else的结构,根据EndpointConfig的状况决定构建怎样的NLG实例。

_load_from_module_string
def _load_from_module_string(
    endpoint_config: EndpointConfig, domain: Domain
) -> "NaturalLanguageGenerator":
    """Initializes a custom natural language generator.

    Args:
        domain: defines the universe in which the assistant operates
        endpoint_config: the specific natural language generator
    """

    try:
        nlg_class = common.class_from_module_path(endpoint_config.type)
        return nlg_class(endpoint_config=endpoint_config, domain=domain)
    except (AttributeError, ImportError) as e:
        raise Exception(
            f"Could not find a class based on the module path "
            f"'{endpoint_config.type}'. Failed to create a "
            f"`NaturalLanguageGenerator` instance. Error: {e}"
        )

TemplatedNaturalLanguageGenerator

TNLG继承自NLG,除了NLG的成员函数之外,还有以下新成员:

  • _templates_for_utter_action
  • _random_template_for
  • generate
  • generate_from_slots
  • _fill_template
  • _template_variables
    首先来看最重要的generate
generate
async def generate(
    self,
    template_name: Text,
    tracker: DialogueStateTracker,
    output_channel: Text,
    **kwargs: Any,
) -> Optional[Dict[Text, Any]]:
    """Generate a response for the requested template."""

    filled_slots = tracker.current_slot_values()
    return self.generate_from_slots(
        template_name, filled_slots, output_channel, **kwargs
    )

输入是模板名和tracker对象,在模板中填充tracker记录的槽位生成回复语句。生成语句这里调用的是generate_from_slots函数。

generate_from_slots
def generate_from_slots(
    self,
    template_name: Text,
    filled_slots: Dict[Text, Any],
    output_channel: Text,
    **kwargs: Any,
) -> Optional[Dict[Text, Any]]:
    """Generate a response for the requested template."""

    # Fetching a random template for the passed template name
    r = copy.deepcopy(self._random_template_for(template_name, output_channel))
    # Filling the slots in the template and returning the template
    if r is not None:
        return self._fill_template(r, filled_slots, **kwargs)
    else:
        return None

这里调用_random_template_for随机选择模板(一个action可能对应多个回复模板),然后调用_fill_template填充模板中的槽位。
先来看_random_template_for。

_random_template_for
def _random_template_for(
    self, utter_action: Text, output_channel: Text
) -> Optional[Dict[Text, Any]]:
    """Select random template for the utter action from available ones.

    If channel-specific templates for the current output channel are given,
    only choose from channel-specific ones.
    """
    import numpy as np

    if utter_action in self.templates:
        suitable_templates = self._templates_for_utter_action(
            utter_action, output_channel
        )

        if suitable_templates:
            return np.random.choice(suitable_templates)
        else:
            return None
    else:
        return None

调用_templates_for_utter_action函数拿到当前action的所有模板,使用np.random.choice在模板列表中随机选择一个。可以看到,输入是action名,返回的template其实是一个 dict 对象。

_fill_template

_fill_template将对选择的模板进行槽位填充的工作。

def _fill_template(
    self,
    template: Dict[Text, Any],
    filled_slots: Optional[Dict[Text, Any]] = None,
    **kwargs: Any,
) -> Dict[Text, Any]:
    """"Combine slot values and key word arguments to fill templates."""

    # Getting the slot values in the template variables
    template_vars = self._template_variables(filled_slots, kwargs)

    keys_to_interpolate = [
        "text",
        "image",
        "custom",
        "button",
        "attachment",
        "quick_replies",
    ]
    if template_vars:
        for key in keys_to_interpolate:
            if key in template:
                template[key] = interpolate(template[key], template_vars)
    return template

可以看到,输入的模板template和填充槽位filled_slots都是dict对象。暂时没有看到具体的例子,猜测:
filled_slots中的所有key都是template中的槽位名,value是对槽位的填充值value,通过替换template中的槽位填充值,完成回复语句的生成。

interpolate.py

在实现TNLG的回复生成阶段,调用了interpolate.py下的两个模块 interpolate和interpolate_text。interpolate_text用于对text格式的template进行槽位填充,使用正则表达式替换和str.format()的形式:

def interpolate_text(template: Text, values: Dict[Text, Text]) -> Text:
    # transforming template tags from
    # "{tag_name}" to "{0[tag_name]}"
    # as described here:
    # https://stackoverflow.com/questions/7934620/python-dots-in-the-name-of-variable-in-a-format-string#comment9695339_7934969
    # black list character and make sure to not to allow
    # (a) newline in slot name
    # (b) { or } in slot name
    try:
        text = re.sub(r"{([^\n{}]+?)}", r"{0[\1]}", template)
        text = text.format(values)
        if "0[" in text:
            # regex replaced tag but format did not replace
            # likely cause would be that tag name was enclosed
            # in double curly and format func simply escaped it.
            # we don't want to return {0[SLOTNAME]} thus
            # restoring original value with { being escaped.
            return template.format({})

        return text
    except KeyError as e:
        logger.exception(
            "Failed to fill utterance template '{}'. "
            "Tried to replace '{}' but could not find "
            "a value for it. There is no slot with this "
            "name nor did you pass the value explicitly "
            "when calling the template. Return template "
            "without filling the template. "
            "".format(template, e.args[0])
        )
        return template

CallbackNaturalLanguageGenerator

最后,来看CNLG。CNLG的结构要简单很多,仅包含两个成员函数,一个产生回复的generate,另一个用于检验回复格式是否合法的validate_response。

generate
async def generate(
    self,
    template_name: Text,
    tracker: DialogueStateTracker,
    output_channel: Text,
    **kwargs: Any,
) -> Dict[Text, Any]:
    """Retrieve a named template from the domain using an endpoint."""

    body = nlg_request_format(template_name, tracker, output_channel, **kwargs)

    logger.debug(
        "Requesting NLG for {} from {}."
        "".format(template_name, self.nlg_endpoint.url)
    )

    response = await self.nlg_endpoint.request(
        method="post", json=body, timeout=DEFAULT_REQUEST_TIMEOUT
    )

    if self.validate_response(response):
        return response
    else:
        raise Exception("NLG web endpoint returned an invalid response.")

输入是action的名称,用于记录的tracker,以及output_channel。首先从nlg_request_format函数中得到request的body,之后向endpoint上的服务发出请求,调用定义在对应Action类中的run函数,得到response,验证response的合法性,并且返回。

nlg_request_format
def nlg_request_format(
    template_name: Text,
    tracker: DialogueStateTracker,
    output_channel: Text,
    **kwargs: Any,
) -> Dict[Text, Any]:
    """Create the json body for the NLG json body for the request."""

    tracker_state = tracker.current_state(EventVerbosity.ALL)

    return {
        "template": template_name,
        "arguments": kwargs,
        "tracker": tracker_state,
        "channel": {"name": output_channel},
    }

这个函数处理产生request的主体,用于指定Action的调用。在写Action的时候就很好奇,Action类的run函数一般定义成这样:def run(self, dispatcher, tracker, domain),后来就很神奇的发现这里边的tracker并不是一个rasa_core.trackers,包含的信息比较少。果然,这里产生的tracker,仅仅保留了当前状态。

最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 159,716评论 4 364
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 67,558评论 1 294
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 109,431评论 0 244
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 44,127评论 0 209
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 52,511评论 3 287
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 40,692评论 1 222
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 31,915评论 2 313
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 30,664评论 0 202
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 34,412评论 1 246
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 30,616评论 2 245
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 32,105评论 1 260
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 28,424评论 2 254
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 33,098评论 3 238
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 26,096评论 0 8
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 26,869评论 0 197
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 35,748评论 2 276
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 35,641评论 2 271

推荐阅读更多精彩内容

  • 一、Rasa Rasa是一个开源机器学习框架,用于构建上下文AI助手和聊天机器人。Rasa有两个主要模块: Ras...
    风玲儿阅读 51,823评论 1 30
  • [TOC] Rasa学习笔记2--Rasa Core 1. 概念介绍 首先引出Rasa的设计理念:Learning...
    ColdCoder阅读 2,949评论 1 4
  • 模板标签除了几个常用的,还真心没有仔细了解一下,看到2.0发布后,翻译学习一下。 本文尽量忠实原著,毕竟大神的东西...
    海明_fd17阅读 1,918评论 0 5
  • Linux基础入门第二节实验报告 1、重要快捷键: 【Tab】 好处就是当你忘记某个命令的全称时可以只输入它的开头...
    小公举凡阅读 238评论 0 1
  • 从我的世界里逃出来 多想 进入另一个人的世界 期待着 与他 不期而遇 沿着他的路径 再走一遍 当年的情景 期待着 ...
    青果未熟阅读 581评论 0 3