mirror of
https://github.com/modelscope/ms-swift.git
synced 2024-11-25 18:32:44 +08:00
support customizable loss scale (#997)
This commit is contained in:
parent
5d15772d78
commit
1f78d56e14
@ -2,3 +2,4 @@ recursive-include swift/utils *.py
|
||||
recursive-include swift/llm/data *.*
|
||||
recursive-include swift/llm/ds_config *.json
|
||||
recursive-include requirements *.txt
|
||||
recursive-include swift/llm/agent *.json
|
||||
|
@ -252,6 +252,72 @@ curl -X POST http://localhost:8000/v1/chat/completions \
|
||||
|
||||
在返回结果的tool_calls中,可以获得调用的函数以及参数信息。
|
||||
|
||||
你也可以通过OpenAI SDK进行测试
|
||||
```python
|
||||
from openai import OpenAI
|
||||
client = OpenAI(
|
||||
api_key='EMPTY',
|
||||
base_url='http://localhost:8000/v1',
|
||||
)
|
||||
query = "What's the weather like in Boston today?"
|
||||
messages = [{
|
||||
'role': 'user',
|
||||
'content': query
|
||||
}]
|
||||
tools = [
|
||||
{
|
||||
"name": "url_for_newapi",
|
||||
"description": "This is the subfunction for tool \"newapi\", you can use this tool.The description of this function is: \"url_for_newapi\"",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"url": {
|
||||
"type": "string",
|
||||
"description": "",
|
||||
"example_value": "https://www.instagram.com/reels/CtB6vWMMHFD/"
|
||||
}
|
||||
},
|
||||
"required": [
|
||||
"url"
|
||||
],
|
||||
"optional": [
|
||||
"url"
|
||||
]
|
||||
}
|
||||
},
|
||||
]
|
||||
resp = client.chat.completions.create(
|
||||
model='llama3-8b-instruct',
|
||||
tools = tools,
|
||||
messages=messages,
|
||||
seed=42)
|
||||
tool_calls = resp.choices[0].message.tool_calls
|
||||
print(f'query: {query}')
|
||||
print(f'tool_calls: {tool_calls}')
|
||||
|
||||
# 流式
|
||||
stream_resp = client.chat.completions.create(
|
||||
model='llama3-8b-instruct',
|
||||
messages=messages,
|
||||
tools=tools,
|
||||
stream=True,
|
||||
seed=42)
|
||||
|
||||
print(f'query: {query}')
|
||||
print('response: ', end='')
|
||||
for chunk in stream_resp:
|
||||
print(chunk.choices[0].delta.content, end='', flush=True)
|
||||
print()
|
||||
|
||||
"""
|
||||
query: What's the weather like in Boston today?
|
||||
tool_calls: {'id': 'toolcall-e4c637435e754cf9b2034c3e6861a4ad', 'function': {'arguments': ' {"url": "https://api.weatherapi.com/v1/current.json?key=YOUR_API_KEY&q=Boston"}', 'name': 'url_for_newapi'}, 'type': 'function'}
|
||||
query: What's the weather like in Boston today?
|
||||
response: Thought: I need to find the weather information for Boston today. I can use the 'newapi' tool to get the weather forecast.
|
||||
Action: url_for_newapi
|
||||
Action Input: {"url": "https://api.weatherapi.com/v1/current.json?key=YOUR_API_KEY&q=Boston"}
|
||||
"""
|
||||
```
|
||||
假设调用返回的结果为`The weather in Boston today is 32°F (0°C), with clear skies`, 我们将结果在role tool字段填入message传入
|
||||
```shell
|
||||
curl -X POST http://localhost:8000/v1/chat/completions \
|
||||
|
@ -124,6 +124,11 @@
|
||||
- `--train_dataset_mix_ratio`: 默认为`0.`. 该参数定义了如何进行数据集打混训练. 指定该参数时, 会混合训练集的`train_dataset_mix_ratio`倍数的`train_dataset_mix_ds`指定的通用知识数据集. 该参数已废弃, 请使用`--dataset`进行数据集混合.
|
||||
- `--train_dataset_mix_ds`: 默认为`['ms-bench']`. 用于防止知识遗忘的通用知识数据集. 该参数已废弃, 请使用`--dataset`进行数据集混合.
|
||||
- `--use_loss_scale`: 默认为`False`. 生效时会将Agent的部分字段(Action/Action Input部分)的loss权重加强以强化CoT, 对普通SFT场景没有任何效果.
|
||||
- `--loss_scale_config_path` 选项指定自定义的 loss_scale 配置,适用于在启用 use_loss_scale 时,例如在 Agent 训练中放大 Action 和其他关键 ReAct 字段的损失权重。
|
||||
- 在配置文件中,您可以使用字典格式来设置 loss_scale。每个键代表一个特定字段名,其关联的值设定了该字段及其后续内容的损失缩放倍数。例如,通过设定 `"Observation:": [2, 0]`,当response包含 `xxxx Observation:error` 时,`Observation:` 字段loss将增加到两倍,`error` 部分的loss则不计入。除了字面匹配,配置也支持正则表达式规则,以实现更灵活的匹配,如模式 '<.*?>':[2.0] 将针对所有尖括号括起来的部分损失增加到两倍。字段匹配与正则匹配所对应的损失缩放倍数,分别由长度为2和1的列表表示。
|
||||
- 同时支持匹配query对整段response设置loss_scale, 这在处理像[Agent-FLAN](https://arxiv.org/abs/2403.12881)论文中描述的固定多轮对话查询时极其有用,如果query中包含了预定义键的任一项,相应的响应将采用关联的 loss_scale 值。,你可以参考`swift/llm/agent/agentflan.json`
|
||||
- 默认情况下,我们为 Action:, Action Input:, Thought:, Final Answer:, 和 Observation: 等字段预设了损失缩放值。我们为[alpha-umi](https://arxiv.org/pdf/2401.07324)和[Agent-FLAN](https://arxiv.org/abs/2403.12881)也提供了默认配置,你可以设置为`alpha-umi`和`agent-flan`来使用。默认的配置文件位于`swift/llm/agent`下
|
||||
- 匹配规则的应用优先级,从高到低为:query字段 > response特定字段 > 正则表达式匹配规则。
|
||||
- `--custom_register_path`: 默认为`None`. 传入`.py`文件, 用于注册模板、模型和数据集.
|
||||
- `--custom_dataset_info`: 默认为`None`, 传入外置dataset_info.json的路径、json字符串或者dict. 用于拓展数据集. 格式参考: https://github.com/modelscope/swift/blob/main/swift/llm/data/dataset_info.json
|
||||
- `--device_map_config_path`: 从本地文件中手动配置模型的device_map, 默认为None
|
||||
|
@ -251,6 +251,73 @@ result
|
||||
{"model":"llama3-8b-instruct","choices":[[{"index":0,"message":{"role":"assistant","content":"Question: What's the weather like in Boston today?\n\nThought: I need to get the current weather in Boston to answer this question.\n\nAction: get_current_weather\n\nAction Input: {'location': 'Boston, MA', 'unit': 'fahrenheit'}\n\nObservation:","tool_calls":{"id":"toolcall-f534d907ae254f2ab96e06c25179ddf9","function":{"arguments":" {'location': 'Boston, MA', 'unit': 'fahrenheit'}\n\n","name":"get_current_weather"},"type":"function"}},"finish_reason":"stop"}]],"usage":{"prompt_tokens":262,"completion_tokens":54,"total_tokens":316},"id":"chatcmpl-8630e8d675c941c0aca958a37633a3c9","object":"chat.completion","created":1717590756}
|
||||
```
|
||||
|
||||
You can also test with OpenAI SDK, for example
|
||||
```python
|
||||
from openai import OpenAI
|
||||
client = OpenAI(
|
||||
api_key='EMPTY',
|
||||
base_url='http://localhost:8000/v1',
|
||||
)
|
||||
query = "What's the weather like in Boston today?"
|
||||
messages = [{
|
||||
'role': 'user',
|
||||
'content': query
|
||||
}]
|
||||
tools = [
|
||||
{
|
||||
"name": "url_for_newapi",
|
||||
"description": "This is the subfunction for tool \"newapi\", you can use this tool.The description of this function is: \"url_for_newapi\"",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"url": {
|
||||
"type": "string",
|
||||
"description": "",
|
||||
"example_value": "https://www.instagram.com/reels/CtB6vWMMHFD/"
|
||||
}
|
||||
},
|
||||
"required": [
|
||||
"url"
|
||||
],
|
||||
"optional": [
|
||||
"url"
|
||||
]
|
||||
}
|
||||
},
|
||||
]
|
||||
resp = client.chat.completions.create(
|
||||
model='llama3-8b-instruct',
|
||||
tools = tools,
|
||||
messages=messages,
|
||||
seed=42)
|
||||
tool_calls = resp.choices[0].message.tool_calls
|
||||
print(f'query: {query}')
|
||||
print(f'tool_calls: {tool_calls}')
|
||||
|
||||
# stream
|
||||
stream_resp = client.chat.completions.create(
|
||||
model='llama3-8b-instruct',
|
||||
messages=messages,
|
||||
tools=tools,
|
||||
stream=True,
|
||||
seed=42)
|
||||
|
||||
print(f'query: {query}')
|
||||
print('response: ', end='')
|
||||
for chunk in stream_resp:
|
||||
print(chunk.choices[0].delta.content, end='', flush=True)
|
||||
print()
|
||||
|
||||
"""
|
||||
query: What's the weather like in Boston today?
|
||||
tool_calls: {'id': 'toolcall-e4c637435e754cf9b2034c3e6861a4ad', 'function': {'arguments': ' {"url": "https://api.weatherapi.com/v1/current.json?key=YOUR_API_KEY&q=Boston"}', 'name': 'url_for_newapi'}, 'type': 'function'}
|
||||
query: What's the weather like in Boston today?
|
||||
response: Thought: I need to find the weather information for Boston today. I can use the 'newapi' tool to get the weather forecast.
|
||||
Action: url_for_newapi
|
||||
Action Input: {"url": "https://api.weatherapi.com/v1/current.json?key=YOUR_API_KEY&q=Boston"}
|
||||
"""
|
||||
```
|
||||
|
||||
In the tool_calls of the returned results, you can obtain the information about the called function and its parameters.
|
||||
|
||||
Assuming the returned result is `The weather in Boston today is 32°F (0°C), with clear skies`, we will fill this result into the role as tool and pass it into the message field.
|
||||
|
@ -125,6 +125,11 @@
|
||||
- `--train_dataset_mix_ratio`: Default is `0.`. This parameter defines how to mix datasets for training. When this parameter is specified, it will mix the training dataset with a multiple of `train_dataset_mix_ratio` of the general knowledge dataset specified by `train_dataset_mix_ds`. This parameter has been deprecated, please use `--dataset {dataset_name}#{dataset_sample}` to mix datasets.
|
||||
- `--train_dataset_mix_ds`: Default is `['ms-bench']`. Used for preventing knowledge forgetting, this is the general knowledge dataset. This parameter has been deprecated, please use `--dataset {dataset_name}#{dataset_sample}` to mix datasets.
|
||||
- `--use_loss_scale`: Default is `False`. When taking effect, strengthens loss weight of some Agent fields (Action/Action Input part) to enhance CoT, has no effect in regular SFT scenarios.
|
||||
- `loss_scale_config_path`: option specifies a custom loss_scale configuration, applicable when use_loss_scale is enabled, such as in Agent training to amplify the loss weights for Action and other crucial ReAct fields.
|
||||
- In the configuration file, you can set the loss_scale using a dictionary format. Each key represents a specific field name, and its associated value specifies the loss scaling factor for that field and its subsequent content. For instance, setting `"Observation:": [2, 0]` means that when the response contains `xxxx Observation:error`, the loss for the `Observation:` field will be doubled, while the loss for the `error` portion will not be counted. Besides literal matching, the configuration also supports regular expression rules for more flexible matching; for example, the pattern `'<.*?>':[2.0]` doubles the loss for any content enclosed in angle brackets. The loss scaling factors for field matching and regex matching are respectively indicated by lists of length 2 and 1.
|
||||
- There is also support for setting loss_scale for the entire response based on matching queries, which is extremely useful in dealing with fixed multi-turn dialogue queries described in the [Agent-Flan paper](https://arxiv.org/abs/2403.12881) paper. If the query includes any of the predefined keys, the corresponding response will use the associated loss_scale value. Refer to swift/llm/agent/agentflan.json for an example.
|
||||
- By default, we have preset loss scaling values for fields such as Action:, Action Input:, Thought:, Final Answer:, and Observation:. We also provide default configurations for [alpha-umi](https://arxiv.org/pdf/2401.07324) and [Agent-FLAN](https://arxiv.org/abs/2403.12881), which you can use by setting to alpha-umi and agent-flan respectively. The default configuration files are located under swift/llm/agent.
|
||||
- The application priority of matching rules is as follows, from highest to lowest: query fields > specific response fields > regular expression matching rules.
|
||||
- `--custom_register_path`: Default is `None`. Pass in a `.py` file used to register templates, models, and datasets.
|
||||
- `--custom_dataset_info`: Default is `None`. Pass in the path to an external `dataset_info.json`, a JSON string, or a dictionary. Used to register custom datasets. The format example: https://github.com/modelscope/swift/blob/main/swift/llm/data/dataset_info.json
|
||||
- `device_map_config_path`: Manually configure the model's device map from a local file, defaults to None.
|
||||
|
22
swift/llm/agent/agentflan.json
Normal file
22
swift/llm/agent/agentflan.json
Normal file
@ -0,0 +1,22 @@
|
||||
{
|
||||
"response":{
|
||||
"Name:": [1.0, 3.0],
|
||||
"Action:": [1.0, 3.0],
|
||||
"ACTION:": [1.0,3.0],
|
||||
"Tool:": [1.0, 3.0],
|
||||
"Command": [1.0, 3.0],
|
||||
"Arguments:": [1.0, 3.0],
|
||||
"action input": [1.0, 3.0],
|
||||
"ACTION_INPUT:":[1.0, 3.0],
|
||||
"Action Input:": [1.0, 3.0],
|
||||
"Thought:": [1.0, 1.0],
|
||||
"Final Answer:": [1.0, 1.0],
|
||||
"Observation:": [2.0, 0.0]
|
||||
},
|
||||
"query":{
|
||||
"What is the tool you want to use": [3.0],
|
||||
"What are the required parameter names": [3.0],
|
||||
"What is the value of": [3.0],
|
||||
"What are the required parameter names for this tool": [3.0]
|
||||
}
|
||||
}
|
8
swift/llm/agent/alpha_umi_loss_scale_config.json
Normal file
8
swift/llm/agent/alpha_umi_loss_scale_config.json
Normal file
@ -0,0 +1,8 @@
|
||||
{
|
||||
"Action:": [2.0, 2.0],
|
||||
"Action Input:": [2.0, 2.0],
|
||||
"Thought:": [1.0, 1.0],
|
||||
"Final Answer:": [1.0, 1.0],
|
||||
"Observation:": [2.0, 0.0],
|
||||
"Next:": [2,0, 2.0]
|
||||
}
|
7
swift/llm/agent/default_loss_scale_config.json
Normal file
7
swift/llm/agent/default_loss_scale_config.json
Normal file
@ -0,0 +1,7 @@
|
||||
{
|
||||
"Action:": [2.0, 2.0],
|
||||
"Action Input:": [2.0, 2.0],
|
||||
"Thought:": [1.0, 1.0],
|
||||
"Final Answer:": [1.0, 1.0],
|
||||
"Observation:": [2.0, 0.0]
|
||||
}
|
@ -1,5 +1,5 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
from typing import List, Optional, Tuple, Union
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
|
||||
from swift.utils import get_logger
|
||||
from swift.utils.utils import split_str_parts_by
|
||||
@ -65,7 +65,11 @@ use function Finish->give_up_and_restart.
|
||||
Specifically, you have access to the following APIs: {tool_list}'''
|
||||
|
||||
|
||||
def calculate_loss_scale(response: str, use_loss_scale=False) -> Tuple[List[str], List[float]]:
|
||||
def calculate_loss_scale(query: str,
|
||||
response: str,
|
||||
use_loss_scale=False,
|
||||
response_loss_scale_map: Optional[dict[str, list]] = None,
|
||||
query_loss_scale_map: Optional[dict[str, list]] = None) -> Tuple[List[str], List[float]]:
|
||||
"""Calculate the loss scale by splitting the agent response.
|
||||
|
||||
This algorithm comes from paper: https://arxiv.org/pdf/2309.00986.pdf
|
||||
@ -90,41 +94,35 @@ def calculate_loss_scale(response: str, use_loss_scale=False) -> Tuple[List[str]
|
||||
Returns:
|
||||
A tuple of agent response parts and their weights.
|
||||
"""
|
||||
if 'Action:' in response and 'Observation:' in response and use_loss_scale:
|
||||
agent_keyword = ['Action:', 'Action Input:', 'Thought:', 'Final Answer:', 'Observation:']
|
||||
agent_parts = split_str_parts_by(response, agent_keyword)
|
||||
if use_loss_scale:
|
||||
# query loss scale map
|
||||
if query_loss_scale_map is not None:
|
||||
for key in query_loss_scale_map.keys():
|
||||
if key in query:
|
||||
if isinstance(query_loss_scale_map[key], (float, int)):
|
||||
query_loss_scale_map[key] = [query_loss_scale_map[key]]
|
||||
loss_scale_value = query_loss_scale_map[key][0]
|
||||
return [response], [float(loss_scale_value)]
|
||||
delimiters = list(k for k in response_loss_scale_map.keys() if len(response_loss_scale_map[k]) == 2)
|
||||
agent_parts = split_str_parts_by(response, delimiters)
|
||||
regex_delimiters = {k: v for k, v in response_loss_scale_map.items() if len(v) == 1}
|
||||
if len(regex_delimiters):
|
||||
agent_parts = split_parts_by_regex(agent_parts, regex_delimiters)
|
||||
weights = []
|
||||
agent_content = []
|
||||
for c in agent_parts:
|
||||
if c['key'] in ('Action:', 'Action Input:'):
|
||||
weights += [2.0]
|
||||
weights += [2.0]
|
||||
elif c['key'] in ('Thought:', 'Final Answer:', ''):
|
||||
weights += [1.0]
|
||||
weights += [1.0]
|
||||
elif c['key'] in ('Observation:', ):
|
||||
weights += [2.0]
|
||||
weights += [0.0]
|
||||
agent_content.append(c['key'])
|
||||
agent_content.append(c['content'])
|
||||
return agent_content, weights
|
||||
elif ('Action:' in response or 'Next:' in response) and use_loss_scale: # alpha-umi
|
||||
agent_keyword = ['Next:', 'Action:', 'Action Input:']
|
||||
agent_parts = split_str_parts_by(response, agent_keyword)
|
||||
weights = []
|
||||
agent_content = []
|
||||
for c in agent_parts:
|
||||
if c['key'] in ('Action:', 'Action Input:', 'Next:'):
|
||||
weights += [2.0]
|
||||
weights += [2.0]
|
||||
elif c['key'] in ('Thought:', 'Final Answer:', ''):
|
||||
weights += [1.0]
|
||||
weights += [1.0]
|
||||
elif c['key'] in ('Observation:', ):
|
||||
weights += [2.0]
|
||||
weights += [0.0]
|
||||
agent_content.append(c['key'])
|
||||
agent_content.append(c['content'])
|
||||
if isinstance(c['key'], (float, int)):
|
||||
weights += [c['key']]
|
||||
agent_content.append(c['content'])
|
||||
else:
|
||||
if c['key'] in response_loss_scale_map:
|
||||
weights += [response_loss_scale_map[c['key']][0]]
|
||||
weights += [response_loss_scale_map[c['key']][1]]
|
||||
agent_content.append(c['key'])
|
||||
agent_content.append(c['content'])
|
||||
else:
|
||||
weights += [1.0]
|
||||
agent_content.append(c['content'])
|
||||
return agent_content, weights
|
||||
else:
|
||||
return [response], [1.0]
|
||||
@ -150,6 +148,31 @@ def split_action_action_input(response: str) -> Tuple[Optional[str], Optional[st
|
||||
return action, action_input
|
||||
|
||||
|
||||
def split_parts_by_regex(text_list: list, regex_delimiters: Dict[str, List[float]]):
|
||||
import re
|
||||
compiled_patterns = [(re.compile(pattern), scale) for pattern, scale in regex_delimiters.items()]
|
||||
for i in range(len(text_list) - 1, -1, -1):
|
||||
item = text_list[i]
|
||||
if item.get('key') == '':
|
||||
res_text = item['content']
|
||||
last_idx = 0
|
||||
segments = []
|
||||
|
||||
for pattern, scale in compiled_patterns:
|
||||
matches = list(re.finditer(pattern, res_text))
|
||||
for match in matches:
|
||||
if match.start() > last_idx:
|
||||
segments.append({'key': '', 'content': res_text[last_idx:match.start()]})
|
||||
segments.append({'key': scale[0], 'content': match.group(0)})
|
||||
last_idx = match.end()
|
||||
|
||||
if last_idx < len(res_text):
|
||||
segments.insert(0, {'key': '', 'content': res_text[last_idx:]})
|
||||
|
||||
if segments:
|
||||
text_list[i:i + 1] = segments
|
||||
|
||||
|
||||
def get_tools_prompt(TOOLS: list[dict[str, Union[str, dict]]], prompt_format: str = 'react_en') -> Optional[str]:
|
||||
tool_descs = []
|
||||
tool_names = []
|
||||
|
@ -162,6 +162,13 @@
|
||||
},
|
||||
"tags": ["chat", "agent", "multi-round", "🔥"]
|
||||
},
|
||||
"msagent-pro": {
|
||||
"dataset_id": "iic/MSAgent-Pro",
|
||||
"conversations": {
|
||||
"error_strategy": "delete"
|
||||
},
|
||||
"tags": ["chat", "agent", "multi-round", "🔥"]
|
||||
},
|
||||
"codefuse-python-en": {
|
||||
"dataset_id": "codefuse-ai/CodeExercise-Python-27k",
|
||||
"conversations": {
|
||||
|
@ -197,6 +197,12 @@ def llm_sft(args: SftArguments) -> Dict[str, Union[str, Any]]:
|
||||
if use_model:
|
||||
template_kwargs['model'] = model
|
||||
template_kwargs['use_loss_scale'] = args.use_loss_scale
|
||||
if args.loss_scale_config_path is not None:
|
||||
cwd = os.getcwd()
|
||||
config_path = args.loss_scale_config_path if os.path.isabs(args.loss_scale_config_path) else os.path.join(
|
||||
cwd, args.loss_scale_config_path)
|
||||
with open(config_path, 'r') as json_file:
|
||||
template_kwargs['loss_scale_map'] = json.load(json_file)
|
||||
template_kwargs['tools_prompt'] = args.tools_prompt
|
||||
if args.sequence_parallel_size and args.sequence_parallel_size > 1:
|
||||
template_kwargs['sequence_parallel_size'] = args.sequence_parallel_size
|
||||
|
@ -461,6 +461,7 @@ class SftArguments(ArgumentsBase):
|
||||
dataset_seed: int = 42
|
||||
dataset_test_ratio: float = 0.01
|
||||
use_loss_scale: bool = False # for agent
|
||||
loss_scale_config_path: str = 'DEFAULT'
|
||||
system: Optional[str] = None
|
||||
tools_prompt: Literal['react_en', 'react_zh', 'toolbench'] = 'react_en'
|
||||
max_length: int = 2048 # -1: no limit
|
||||
@ -753,7 +754,16 @@ class SftArguments(ArgumentsBase):
|
||||
if self.deepspeed == ds_name:
|
||||
self.deepspeed = os.path.join(ds_config_folder, ds_config)
|
||||
break
|
||||
|
||||
if self.loss_scale_config_path:
|
||||
if self.loss_scale_config_path == 'DEFAULT':
|
||||
self.loss_scale_config_path = os.path.abspath(
|
||||
os.path.join(__file__, '..', '..', 'agent', 'default_loss_scale_config.json'))
|
||||
elif self.loss_scale_config_path == 'alpha-umi': # https://arxiv.org/pdf/2401.07324
|
||||
self.loss_scale_config_path = os.path.abspath(
|
||||
os.path.join(__file__, '..', '..', 'agent', 'alpha_umi_loss_scale_config.json'))
|
||||
elif self.loss_scale_config_path == 'agent-flan': # https://arxiv.org/abs/2403.12881
|
||||
self.loss_scale_config_path = os.path.abspath(
|
||||
os.path.join(__file__, '..', '..', 'agent', 'agentflan.json'))
|
||||
self.handle_path()
|
||||
self._handle_dataset_sample()
|
||||
self._register_self_cognition()
|
||||
|
@ -72,6 +72,8 @@ class DatasetName:
|
||||
damo_agent_zh = 'damo-agent-zh'
|
||||
damo_agent_zh_mini = 'damo-agent-zh-mini'
|
||||
agent_instruct_all_en = 'agent-instruct-all-en'
|
||||
msagent_pro = 'msagent-pro'
|
||||
toolbench = 'toolbench'
|
||||
|
||||
# coding
|
||||
code_alpaca_en = 'code-alpaca-en'
|
||||
@ -1004,6 +1006,39 @@ register_dataset(
|
||||
tags=['chat', 'agent', 'multi-round', 'role-play', 'multi-agent'])
|
||||
|
||||
|
||||
def _preprocess_toolbench(dataset: HfDataset) -> HfDataset:
|
||||
|
||||
def reorganize_row(row):
|
||||
convs = row['conversations']
|
||||
sys = convs[0]['value']
|
||||
history = []
|
||||
history_roles = []
|
||||
for idx in range(1, len(convs) - 2, 2):
|
||||
history.append((convs[idx]['value'], convs[idx + 1]['value']))
|
||||
history_roles.append((convs[idx]['from'], convs[idx + 1]['from']))
|
||||
|
||||
return {
|
||||
'system': sys,
|
||||
'history': history,
|
||||
'history_roles': history_roles,
|
||||
'query': convs[-2]['value'],
|
||||
'query_role': convs[-2]['from'],
|
||||
'response': convs[-1]['value']
|
||||
}
|
||||
|
||||
return dataset.map(reorganize_row)
|
||||
|
||||
|
||||
register_dataset(
|
||||
DatasetName.toolbench,
|
||||
'swift/ToolBench',
|
||||
None,
|
||||
_preprocess_toolbench,
|
||||
get_dataset_from_repo,
|
||||
remove_useless_columns=False,
|
||||
tags=['chat', 'agent', 'multi-round'])
|
||||
|
||||
|
||||
def _preprocess_hc3(dataset: HfDataset) -> HfDataset:
|
||||
prompt = """Classification Task: Are the following responses from a human or from ChatGPT?
|
||||
Question: {question}
|
||||
|
@ -224,6 +224,15 @@ class Template:
|
||||
self.truncation_strategy = truncation_strategy
|
||||
self.model = kwargs.get('model', None)
|
||||
self.use_loss_scale = kwargs.get('use_loss_scale', False)
|
||||
self.response_loss_scale_map = kwargs.get('loss_scale_map', None)
|
||||
self.query_loss_scale_map = None
|
||||
if self.response_loss_scale_map is not None:
|
||||
if 'query' in self.response_loss_scale_map and isinstance(self.response_loss_scale_map['query'], dict):
|
||||
self.query_loss_scale_map = self.response_loss_scale_map['query']
|
||||
if 'response' in self.response_loss_scale_map and isinstance(self.response_loss_scale_map['response'],
|
||||
dict):
|
||||
self.response_loss_scale_map = self.response_loss_scale_map['response']
|
||||
|
||||
self.sequence_parallel_size = kwargs.get('sequence_parallel_size', 1)
|
||||
|
||||
for key in ['prefix', 'prompt', 'chat_sep', 'suffix', 'prefix_has_system']:
|
||||
@ -277,15 +286,14 @@ class Template:
|
||||
return inputs, tokenizer_kwargs
|
||||
|
||||
def _concat_context_list(
|
||||
self,
|
||||
context_list: List[Context],
|
||||
res_context_list: List[Context], # inplace
|
||||
loss_scale_list: List[float], # inplace
|
||||
system: Optional[str] = None,
|
||||
query: Optional[str] = None,
|
||||
response: Optional[str] = None,
|
||||
round0: Optional[int] = None,
|
||||
) -> None:
|
||||
self,
|
||||
context_list: List[Context],
|
||||
res_context_list: List[Context], # inplace
|
||||
loss_scale_list: List[float], # inplace
|
||||
system: Optional[str] = None,
|
||||
query: Optional[str] = None,
|
||||
response: Optional[str] = None,
|
||||
round0: Optional[int] = None) -> None:
|
||||
# concat context list and replace placeholder
|
||||
round1 = None
|
||||
if round0 is not None:
|
||||
@ -295,7 +303,9 @@ class Template:
|
||||
if isinstance(context, str):
|
||||
if '{{RESPONSE}}' == context:
|
||||
assert response is not None
|
||||
content_part, weight_part = calculate_loss_scale(response, self.use_loss_scale)
|
||||
content_part, weight_part = calculate_loss_scale(query, response, self.use_loss_scale,
|
||||
self.response_loss_scale_map,
|
||||
self.query_loss_scale_map)
|
||||
res_context_list.extend(content_part)
|
||||
loss_scale_list.extend(weight_part)
|
||||
continue
|
||||
@ -408,7 +418,6 @@ class Template:
|
||||
if q or r:
|
||||
self._concat_context_list(
|
||||
context_list, res_context_list, loss_scale_list, query=q, response=r, round0=i)
|
||||
|
||||
res_context_list, loss_scale_list = self._simplify_context_list(res_context_list, loss_scale_list)
|
||||
input_ids, labels, loss_scale, tokenizer_kwargs = self._encode_context_list(res_context_list, loss_scale_list)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user