diff --git a/MANIFEST.in b/MANIFEST.in index e6294c7d5..ab80e86b6 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -2,3 +2,4 @@ recursive-include swift/utils *.py recursive-include swift/llm/data *.* recursive-include swift/llm/ds_config *.json recursive-include requirements *.txt +recursive-include swift/llm/agent *.json diff --git a/docs/source/LLM/Agent部署最佳实践.md b/docs/source/LLM/Agent部署最佳实践.md index cd4a218a6..7f4c34866 100644 --- a/docs/source/LLM/Agent部署最佳实践.md +++ b/docs/source/LLM/Agent部署最佳实践.md @@ -252,6 +252,72 @@ curl -X POST http://localhost:8000/v1/chat/completions \ 在返回结果的tool_calls中,可以获得调用的函数以及参数信息。 +你也可以通过OpenAI SDK进行测试 +```python +from openai import OpenAI +client = OpenAI( + api_key='EMPTY', + base_url='http://localhost:8000/v1', +) +query = "What's the weather like in Boston today?" +messages = [{ + 'role': 'user', + 'content': query +}] +tools = [ + { + "name": "url_for_newapi", + "description": "This is the subfunction for tool \"newapi\", you can use this tool.The description of this function is: \"url_for_newapi\"", + "parameters": { + "type": "object", + "properties": { + "url": { + "type": "string", + "description": "", + "example_value": "https://www.instagram.com/reels/CtB6vWMMHFD/" + } + }, + "required": [ + "url" + ], + "optional": [ + "url" + ] + } + }, +] +resp = client.chat.completions.create( + model='llama3-8b-instruct', + tools = tools, + messages=messages, + seed=42) +tool_calls = resp.choices[0].message.tool_calls +print(f'query: {query}') +print(f'tool_calls: {tool_calls}') + +# 流式 +stream_resp = client.chat.completions.create( + model='llama3-8b-instruct', + messages=messages, + tools=tools, + stream=True, + seed=42) + +print(f'query: {query}') +print('response: ', end='') +for chunk in stream_resp: + print(chunk.choices[0].delta.content, end='', flush=True) +print() + +""" +query: What's the weather like in Boston today? 
+tool_calls: {'id': 'toolcall-e4c637435e754cf9b2034c3e6861a4ad', 'function': {'arguments': ' {"url": "https://api.weatherapi.com/v1/current.json?key=YOUR_API_KEY&q=Boston"}', 'name': 'url_for_newapi'}, 'type': 'function'} +query: What's the weather like in Boston today? +response: Thought: I need to find the weather information for Boston today. I can use the 'newapi' tool to get the weather forecast. +Action: url_for_newapi +Action Input: {"url": "https://api.weatherapi.com/v1/current.json?key=YOUR_API_KEY&q=Boston"} +""" +``` 假设调用返回的结果为`The weather in Boston today is 32°F (0°C), with clear skies`, 我们将结果在role tool字段填入message传入 ```shell curl -X POST http://localhost:8000/v1/chat/completions \ diff --git a/docs/source/LLM/命令行参数.md b/docs/source/LLM/命令行参数.md index 42a23ff6b..b74c7dba9 100644 --- a/docs/source/LLM/命令行参数.md +++ b/docs/source/LLM/命令行参数.md @@ -124,6 +124,11 @@ - `--train_dataset_mix_ratio`: 默认为`0.`. 该参数定义了如何进行数据集打混训练. 指定该参数时, 会混合训练集的`train_dataset_mix_ratio`倍数的`train_dataset_mix_ds`指定的通用知识数据集. 该参数已废弃, 请使用`--dataset`进行数据集混合. - `--train_dataset_mix_ds`: 默认为`['ms-bench']`. 用于防止知识遗忘的通用知识数据集. 该参数已废弃, 请使用`--dataset`进行数据集混合. - `--use_loss_scale`: 默认为`False`. 生效时会将Agent的部分字段(Action/Action Input部分)的loss权重加强以强化CoT, 对普通SFT场景没有任何效果. 
+- `--loss_scale_config_path`: 默认为`DEFAULT`. 指定自定义的 loss_scale 配置,在启用`use_loss_scale`时生效,例如在 Agent 训练中放大 Action 和其他关键 ReAct 字段的损失权重。 + - 在配置文件中,您可以使用字典格式来设置 loss_scale。每个键代表一个特定字段名,其关联的值设定了该字段及其后续内容的损失缩放倍数。例如,通过设定 `"Observation:": [2, 0]`,当response包含 `xxxx Observation:error` 时,`Observation:` 字段loss将增加到两倍,`error` 部分的loss则不计入。除了字面匹配,配置也支持正则表达式规则,以实现更灵活的匹配,如模式 `'<.*?>': [2.0]` 会将所有被尖括号括起来的部分的损失增加到两倍。字段匹配与正则匹配所对应的损失缩放倍数,分别由长度为2和1的列表表示。 + - 同时支持匹配query对整段response设置loss_scale, 这在处理像[Agent-FLAN](https://arxiv.org/abs/2403.12881)论文中描述的固定多轮对话查询时极其有用。如果query中包含了预定义键的任一项,相应的响应将采用关联的 loss_scale 值。你可以参考`swift/llm/agent/agentflan.json`。 + - 默认情况下,我们为 Action:, Action Input:, Thought:, Final Answer:, 和 Observation: 等字段预设了损失缩放值。我们为[alpha-umi](https://arxiv.org/pdf/2401.07324)和[Agent-FLAN](https://arxiv.org/abs/2403.12881)也提供了默认配置,你可以将该参数设置为`alpha-umi`或`agent-flan`来使用。默认的配置文件位于`swift/llm/agent`下。 + - 匹配规则的应用优先级,从高到低为:query字段 > response特定字段 > 正则表达式匹配规则。 - `--custom_register_path`: 默认为`None`. 传入`.py`文件, 用于注册模板、模型和数据集. - `--custom_dataset_info`: 默认为`None`, 传入外置dataset_info.json的路径、json字符串或者dict. 用于拓展数据集. 
格式参考: https://github.com/modelscope/swift/blob/main/swift/llm/data/dataset_info.json - `--device_map_config_path`: 从本地文件中手动配置模型的device_map, 默认为None diff --git a/docs/source_en/LLM/Agent-deployment-best-practice.md b/docs/source_en/LLM/Agent-deployment-best-practice.md index 9bdac6852..80d9516ee 100644 --- a/docs/source_en/LLM/Agent-deployment-best-practice.md +++ b/docs/source_en/LLM/Agent-deployment-best-practice.md @@ -251,6 +251,73 @@ result {"model":"llama3-8b-instruct","choices":[[{"index":0,"message":{"role":"assistant","content":"Question: What's the weather like in Boston today?\n\nThought: I need to get the current weather in Boston to answer this question.\n\nAction: get_current_weather\n\nAction Input: {'location': 'Boston, MA', 'unit': 'fahrenheit'}\n\nObservation:","tool_calls":{"id":"toolcall-f534d907ae254f2ab96e06c25179ddf9","function":{"arguments":" {'location': 'Boston, MA', 'unit': 'fahrenheit'}\n\n","name":"get_current_weather"},"type":"function"}},"finish_reason":"stop"}]],"usage":{"prompt_tokens":262,"completion_tokens":54,"total_tokens":316},"id":"chatcmpl-8630e8d675c941c0aca958a37633a3c9","object":"chat.completion","created":1717590756} ``` +You can also test with OpenAI SDK, for example +```python +from openai import OpenAI +client = OpenAI( + api_key='EMPTY', + base_url='http://localhost:8000/v1', +) +query = "What's the weather like in Boston today?" 
+messages = [{ + 'role': 'user', + 'content': query +}] +tools = [ + { + "name": "url_for_newapi", + "description": "This is the subfunction for tool \"newapi\", you can use this tool.The description of this function is: \"url_for_newapi\"", + "parameters": { + "type": "object", + "properties": { + "url": { + "type": "string", + "description": "", + "example_value": "https://www.instagram.com/reels/CtB6vWMMHFD/" + } + }, + "required": [ + "url" + ], + "optional": [ + "url" + ] + } + }, +] +resp = client.chat.completions.create( + model='llama3-8b-instruct', + tools = tools, + messages=messages, + seed=42) +tool_calls = resp.choices[0].message.tool_calls +print(f'query: {query}') +print(f'tool_calls: {tool_calls}') + +# stream +stream_resp = client.chat.completions.create( + model='llama3-8b-instruct', + messages=messages, + tools=tools, + stream=True, + seed=42) + +print(f'query: {query}') +print('response: ', end='') +for chunk in stream_resp: + print(chunk.choices[0].delta.content, end='', flush=True) +print() + +""" +query: What's the weather like in Boston today? +tool_calls: {'id': 'toolcall-e4c637435e754cf9b2034c3e6861a4ad', 'function': {'arguments': ' {"url": "https://api.weatherapi.com/v1/current.json?key=YOUR_API_KEY&q=Boston"}', 'name': 'url_for_newapi'}, 'type': 'function'} +query: What's the weather like in Boston today? +response: Thought: I need to find the weather information for Boston today. I can use the 'newapi' tool to get the weather forecast. +Action: url_for_newapi +Action Input: {"url": "https://api.weatherapi.com/v1/current.json?key=YOUR_API_KEY&q=Boston"} +""" +``` + In the tool_calls of the returned results, you can obtain the information about the called function and its parameters. Assuming the returned result is `The weather in Boston today is 32°F (0°C), with clear skies`, we will fill this result into the role as tool and pass it into the message field. 
diff --git a/docs/source_en/LLM/Command-line-parameters.md b/docs/source_en/LLM/Command-line-parameters.md index 82a2e1020..3eea5c4a9 100644 --- a/docs/source_en/LLM/Command-line-parameters.md +++ b/docs/source_en/LLM/Command-line-parameters.md @@ -125,6 +125,11 @@ - `--train_dataset_mix_ratio`: Default is `0.`. This parameter defines how to mix datasets for training. When this parameter is specified, it will mix the training dataset with a multiple of `train_dataset_mix_ratio` of the general knowledge dataset specified by `train_dataset_mix_ds`. This parameter has been deprecated, please use `--dataset {dataset_name}#{dataset_sample}` to mix datasets. - `--train_dataset_mix_ds`: Default is `['ms-bench']`. Used for preventing knowledge forgetting, this is the general knowledge dataset. This parameter has been deprecated, please use `--dataset {dataset_name}#{dataset_sample}` to mix datasets. - `--use_loss_scale`: Default is `False`. When taking effect, strengthens loss weight of some Agent fields (Action/Action Input part) to enhance CoT, has no effect in regular SFT scenarios. +- `--loss_scale_config_path`: Default is `DEFAULT`. Specifies a custom loss_scale configuration, which takes effect when `use_loss_scale` is enabled, for example in Agent training to amplify the loss weights of Action and other crucial ReAct fields. + - In the configuration file, you can set the loss_scale using a dictionary format. Each key represents a specific field name, and its associated value specifies the loss scaling factor for that field and its subsequent content. For instance, setting `"Observation:": [2, 0]` means that when the response contains `xxxx Observation:error`, the loss for the `Observation:` field will be doubled, while the loss for the `error` portion will not be counted. Besides literal matching, the configuration also supports regular expression rules for more flexible matching; for example, the pattern `'<.*?>':[2.0]` doubles the loss for any content enclosed in angle brackets. 
The loss scaling factors for field matching and regex matching are given as lists of length 2 and length 1, respectively. + - There is also support for setting loss_scale for the entire response based on matching queries, which is extremely useful in dealing with the fixed multi-turn dialogue queries described in the [Agent-FLAN](https://arxiv.org/abs/2403.12881) paper. If the query includes any of the predefined keys, the corresponding response will use the associated loss_scale value. Refer to `swift/llm/agent/agentflan.json` for an example. + - By default, we have preset loss scaling values for fields such as Action:, Action Input:, Thought:, Final Answer:, and Observation:. We also provide default configurations for [alpha-umi](https://arxiv.org/pdf/2401.07324) and [Agent-FLAN](https://arxiv.org/abs/2403.12881), which you can use by setting this option to `alpha-umi` or `agent-flan` respectively. The default configuration files are located under `swift/llm/agent`. + - The application priority of matching rules is as follows, from highest to lowest: query fields > specific response fields > regular expression matching rules. - `--custom_register_path`: Default is `None`. Pass in a `.py` file used to register templates, models, and datasets. - `--custom_dataset_info`: Default is `None`. Pass in the path to an external `dataset_info.json`, a JSON string, or a dictionary. Used to register custom datasets. The format example: https://github.com/modelscope/swift/blob/main/swift/llm/data/dataset_info.json - `device_map_config_path`: Manually configure the model's device map from a local file, defaults to None. 
diff --git a/swift/llm/agent/agentflan.json b/swift/llm/agent/agentflan.json new file mode 100644 index 000000000..2751fea02 --- /dev/null +++ b/swift/llm/agent/agentflan.json @@ -0,0 +1,22 @@ +{ + "response":{ + "Name:": [1.0, 3.0], + "Action:": [1.0, 3.0], + "ACTION:": [1.0,3.0], + "Tool:": [1.0, 3.0], + "Command": [1.0, 3.0], + "Arguments:": [1.0, 3.0], + "action input": [1.0, 3.0], + "ACTION_INPUT:":[1.0, 3.0], + "Action Input:": [1.0, 3.0], + "Thought:": [1.0, 1.0], + "Final Answer:": [1.0, 1.0], + "Observation:": [2.0, 0.0] + }, + "query":{ + "What is the tool you want to use": [3.0], + "What are the required parameter names": [3.0], + "What is the value of": [3.0], + "What are the required parameter names for this tool": [3.0] + } +} diff --git a/swift/llm/agent/alpha_umi_loss_scale_config.json b/swift/llm/agent/alpha_umi_loss_scale_config.json new file mode 100644 index 000000000..fcdcbcb18 --- /dev/null +++ b/swift/llm/agent/alpha_umi_loss_scale_config.json @@ -0,0 +1,8 @@ +{ + "Action:": [2.0, 2.0], + "Action Input:": [2.0, 2.0], + "Thought:": [1.0, 1.0], + "Final Answer:": [1.0, 1.0], + "Observation:": [2.0, 0.0], + "Next:": [2.0, 2.0] +} diff --git a/swift/llm/agent/default_loss_scale_config.json b/swift/llm/agent/default_loss_scale_config.json new file mode 100644 index 000000000..006f92948 --- /dev/null +++ b/swift/llm/agent/default_loss_scale_config.json @@ -0,0 +1,7 @@ +{ + "Action:": [2.0, 2.0], + "Action Input:": [2.0, 2.0], + "Thought:": [1.0, 1.0], + "Final Answer:": [1.0, 1.0], + "Observation:": [2.0, 0.0] +} diff --git a/swift/llm/agent/utils.py b/swift/llm/agent/utils.py index 628529e94..a0265d3a0 100644 --- a/swift/llm/agent/utils.py +++ b/swift/llm/agent/utils.py @@ -1,5 +1,5 @@ # Copyright (c) Alibaba, Inc. and its affiliates. 
-from typing import List, Optional, Tuple, Union +from typing import Dict, List, Optional, Tuple, Union from swift.utils import get_logger from swift.utils.utils import split_str_parts_by @@ -65,7 +65,11 @@ use function Finish->give_up_and_restart. Specifically, you have access to the following APIs: {tool_list}''' -def calculate_loss_scale(response: str, use_loss_scale=False) -> Tuple[List[str], List[float]]: +def calculate_loss_scale(query: str, + response: str, + use_loss_scale=False, + response_loss_scale_map: Optional[dict[str, list]] = None, + query_loss_scale_map: Optional[dict[str, list]] = None) -> Tuple[List[str], List[float]]: """Calculate the loss scale by splitting the agent response. This algorithm comes from paper: https://arxiv.org/pdf/2309.00986.pdf @@ -90,41 +94,35 @@ def calculate_loss_scale(response: str, use_loss_scale=False) -> Tuple[List[str] Returns: A tuple of agent response parts and their weights. """ - if 'Action:' in response and 'Observation:' in response and use_loss_scale: - agent_keyword = ['Action:', 'Action Input:', 'Thought:', 'Final Answer:', 'Observation:'] - agent_parts = split_str_parts_by(response, agent_keyword) + if use_loss_scale: + # query loss scale map + if query_loss_scale_map is not None: + for key in query_loss_scale_map.keys(): + if key in query: + if isinstance(query_loss_scale_map[key], (float, int)): + query_loss_scale_map[key] = [query_loss_scale_map[key]] + loss_scale_value = query_loss_scale_map[key][0] + return [response], [float(loss_scale_value)] + delimiters = list(k for k in response_loss_scale_map.keys() if len(response_loss_scale_map[k]) == 2) + agent_parts = split_str_parts_by(response, delimiters) + regex_delimiters = {k: v for k, v in response_loss_scale_map.items() if len(v) == 1} + if len(regex_delimiters): + agent_parts = split_parts_by_regex(agent_parts, regex_delimiters) weights = [] agent_content = [] for c in agent_parts: - if c['key'] in ('Action:', 'Action Input:'): - weights += [2.0] - 
weights += [2.0] - elif c['key'] in ('Thought:', 'Final Answer:', ''): - weights += [1.0] - weights += [1.0] - elif c['key'] in ('Observation:', ): - weights += [2.0] - weights += [0.0] - agent_content.append(c['key']) - agent_content.append(c['content']) - return agent_content, weights - elif ('Action:' in response or 'Next:' in response) and use_loss_scale: # alpha-umi - agent_keyword = ['Next:', 'Action:', 'Action Input:'] - agent_parts = split_str_parts_by(response, agent_keyword) - weights = [] - agent_content = [] - for c in agent_parts: - if c['key'] in ('Action:', 'Action Input:', 'Next:'): - weights += [2.0] - weights += [2.0] - elif c['key'] in ('Thought:', 'Final Answer:', ''): - weights += [1.0] - weights += [1.0] - elif c['key'] in ('Observation:', ): - weights += [2.0] - weights += [0.0] - agent_content.append(c['key']) - agent_content.append(c['content']) + if isinstance(c['key'], (float, int)): + weights += [c['key']] + agent_content.append(c['content']) + else: + if c['key'] in response_loss_scale_map: + weights += [response_loss_scale_map[c['key']][0]] + weights += [response_loss_scale_map[c['key']][1]] + agent_content.append(c['key']) + agent_content.append(c['content']) + else: + weights += [1.0] + agent_content.append(c['content']) return agent_content, weights else: return [response], [1.0] @@ -150,6 +148,31 @@ def split_action_action_input(response: str) -> Tuple[Optional[str], Optional[st return action, action_input +def split_parts_by_regex(text_list: list, regex_delimiters: Dict[str, List[float]]): + import re + compiled_patterns = [(re.compile(pattern), scale) for pattern, scale in regex_delimiters.items()] + for i in range(len(text_list) - 1, -1, -1): + item = text_list[i] + if item.get('key') == '': + res_text = item['content'] + last_idx = 0 + segments = [] + + for pattern, scale in compiled_patterns: + matches = list(re.finditer(pattern, res_text)) + for match in matches: + if match.start() > last_idx: + segments.append({'key': '', 
'content': res_text[last_idx:match.start()]}) + segments.append({'key': scale[0], 'content': match.group(0)}) + last_idx = match.end() + + if last_idx < len(res_text): + segments.insert(0, {'key': '', 'content': res_text[last_idx:]}) + + if segments: + text_list[i:i + 1] = segments + + def get_tools_prompt(TOOLS: list[dict[str, Union[str, dict]]], prompt_format: str = 'react_en') -> Optional[str]: tool_descs = [] tool_names = [] diff --git a/swift/llm/data/dataset_info.json b/swift/llm/data/dataset_info.json index edcb38ee6..87bf0075a 100644 --- a/swift/llm/data/dataset_info.json +++ b/swift/llm/data/dataset_info.json @@ -162,6 +162,13 @@ }, "tags": ["chat", "agent", "multi-round", "🔥"] }, + "msagent-pro": { + "dataset_id": "iic/MSAgent-Pro", + "conversations": { + "error_strategy": "delete" + }, + "tags": ["chat", "agent", "multi-round", "🔥"] + }, "codefuse-python-en": { "dataset_id": "codefuse-ai/CodeExercise-Python-27k", "conversations": { diff --git a/swift/llm/sft.py b/swift/llm/sft.py index 1e213c6f9..5bf529187 100644 --- a/swift/llm/sft.py +++ b/swift/llm/sft.py @@ -197,6 +197,12 @@ def llm_sft(args: SftArguments) -> Dict[str, Union[str, Any]]: if use_model: template_kwargs['model'] = model template_kwargs['use_loss_scale'] = args.use_loss_scale + if args.loss_scale_config_path is not None: + cwd = os.getcwd() + config_path = args.loss_scale_config_path if os.path.isabs(args.loss_scale_config_path) else os.path.join( + cwd, args.loss_scale_config_path) + with open(config_path, 'r') as json_file: + template_kwargs['loss_scale_map'] = json.load(json_file) template_kwargs['tools_prompt'] = args.tools_prompt if args.sequence_parallel_size and args.sequence_parallel_size > 1: template_kwargs['sequence_parallel_size'] = args.sequence_parallel_size diff --git a/swift/llm/utils/argument.py b/swift/llm/utils/argument.py index 744f2fcbf..ac9225d47 100644 --- a/swift/llm/utils/argument.py +++ b/swift/llm/utils/argument.py @@ -461,6 +461,7 @@ class 
SftArguments(ArgumentsBase): dataset_seed: int = 42 dataset_test_ratio: float = 0.01 use_loss_scale: bool = False # for agent + loss_scale_config_path: str = 'DEFAULT' system: Optional[str] = None tools_prompt: Literal['react_en', 'react_zh', 'toolbench'] = 'react_en' max_length: int = 2048 # -1: no limit @@ -753,7 +754,16 @@ class SftArguments(ArgumentsBase): if self.deepspeed == ds_name: self.deepspeed = os.path.join(ds_config_folder, ds_config) break - + if self.loss_scale_config_path: + if self.loss_scale_config_path == 'DEFAULT': + self.loss_scale_config_path = os.path.abspath( + os.path.join(__file__, '..', '..', 'agent', 'default_loss_scale_config.json')) + elif self.loss_scale_config_path == 'alpha-umi': # https://arxiv.org/pdf/2401.07324 + self.loss_scale_config_path = os.path.abspath( + os.path.join(__file__, '..', '..', 'agent', 'alpha_umi_loss_scale_config.json')) + elif self.loss_scale_config_path == 'agent-flan': # https://arxiv.org/abs/2403.12881 + self.loss_scale_config_path = os.path.abspath( + os.path.join(__file__, '..', '..', 'agent', 'agentflan.json')) self.handle_path() self._handle_dataset_sample() self._register_self_cognition() diff --git a/swift/llm/utils/dataset.py b/swift/llm/utils/dataset.py index b4eafd919..3da5a7e70 100644 --- a/swift/llm/utils/dataset.py +++ b/swift/llm/utils/dataset.py @@ -72,6 +72,8 @@ class DatasetName: damo_agent_zh = 'damo-agent-zh' damo_agent_zh_mini = 'damo-agent-zh-mini' agent_instruct_all_en = 'agent-instruct-all-en' + msagent_pro = 'msagent-pro' + toolbench = 'toolbench' # coding code_alpaca_en = 'code-alpaca-en' @@ -1004,6 +1006,39 @@ register_dataset( tags=['chat', 'agent', 'multi-round', 'role-play', 'multi-agent']) +def _preprocess_toolbench(dataset: HfDataset) -> HfDataset: + + def reorganize_row(row): + convs = row['conversations'] + sys = convs[0]['value'] + history = [] + history_roles = [] + for idx in range(1, len(convs) - 2, 2): + history.append((convs[idx]['value'], convs[idx + 1]['value'])) + 
history_roles.append((convs[idx]['from'], convs[idx + 1]['from'])) + + return { + 'system': sys, + 'history': history, + 'history_roles': history_roles, + 'query': convs[-2]['value'], + 'query_role': convs[-2]['from'], + 'response': convs[-1]['value'] + } + + return dataset.map(reorganize_row) + + +register_dataset( + DatasetName.toolbench, + 'swift/ToolBench', + None, + _preprocess_toolbench, + get_dataset_from_repo, + remove_useless_columns=False, + tags=['chat', 'agent', 'multi-round']) + + def _preprocess_hc3(dataset: HfDataset) -> HfDataset: prompt = """Classification Task: Are the following responses from a human or from ChatGPT? Question: {question} diff --git a/swift/llm/utils/template.py b/swift/llm/utils/template.py index 03fcd58b0..151c96b2a 100644 --- a/swift/llm/utils/template.py +++ b/swift/llm/utils/template.py @@ -224,6 +224,15 @@ class Template: self.truncation_strategy = truncation_strategy self.model = kwargs.get('model', None) self.use_loss_scale = kwargs.get('use_loss_scale', False) + self.response_loss_scale_map = kwargs.get('loss_scale_map', None) + self.query_loss_scale_map = None + if self.response_loss_scale_map is not None: + if 'query' in self.response_loss_scale_map and isinstance(self.response_loss_scale_map['query'], dict): + self.query_loss_scale_map = self.response_loss_scale_map['query'] + if 'response' in self.response_loss_scale_map and isinstance(self.response_loss_scale_map['response'], + dict): + self.response_loss_scale_map = self.response_loss_scale_map['response'] + self.sequence_parallel_size = kwargs.get('sequence_parallel_size', 1) for key in ['prefix', 'prompt', 'chat_sep', 'suffix', 'prefix_has_system']: @@ -277,15 +286,14 @@ class Template: return inputs, tokenizer_kwargs def _concat_context_list( - self, - context_list: List[Context], - res_context_list: List[Context], # inplace - loss_scale_list: List[float], # inplace - system: Optional[str] = None, - query: Optional[str] = None, - response: Optional[str] = None, - 
round0: Optional[int] = None, - ) -> None: + self, + context_list: List[Context], + res_context_list: List[Context], # inplace + loss_scale_list: List[float], # inplace + system: Optional[str] = None, + query: Optional[str] = None, + response: Optional[str] = None, + round0: Optional[int] = None) -> None: # concat context list and replace placeholder round1 = None if round0 is not None: @@ -295,7 +303,9 @@ class Template: if isinstance(context, str): if '{{RESPONSE}}' == context: assert response is not None - content_part, weight_part = calculate_loss_scale(response, self.use_loss_scale) + content_part, weight_part = calculate_loss_scale(query, response, self.use_loss_scale, + self.response_loss_scale_map, + self.query_loss_scale_map) res_context_list.extend(content_part) loss_scale_list.extend(weight_part) continue @@ -408,7 +418,6 @@ class Template: if q or r: self._concat_context_list( context_list, res_context_list, loss_scale_list, query=q, response=r, round0=i) - res_context_list, loss_scale_list = self._simplify_context_list(res_context_list, loss_scale_list) input_ids, labels, loss_scale, tokenizer_kwargs = self._encode_context_list(res_context_list, loss_scale_list)