support customizable loss scale (#997)

This commit is contained in:
jinghanhu 2024-06-13 21:40:14 +08:00 committed by GitHub
parent 5d15772d78
commit 1f78d56e14
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
14 changed files with 317 additions and 46 deletions

View File

@ -2,3 +2,4 @@ recursive-include swift/utils *.py
recursive-include swift/llm/data *.*
recursive-include swift/llm/ds_config *.json
recursive-include requirements *.txt
recursive-include swift/llm/agent *.json

View File

@ -252,6 +252,72 @@ curl -X POST http://localhost:8000/v1/chat/completions \
在返回结果的tool_calls中可以获得调用的函数以及参数信息。
你也可以通过OpenAI SDK进行测试
```python
from openai import OpenAI
client = OpenAI(
api_key='EMPTY',
base_url='http://localhost:8000/v1',
)
query = "What's the weather like in Boston today?"
messages = [{
'role': 'user',
'content': query
}]
tools = [
{
"name": "url_for_newapi",
"description": "This is the subfunction for tool \"newapi\", you can use this tool.The description of this function is: \"url_for_newapi\"",
"parameters": {
"type": "object",
"properties": {
"url": {
"type": "string",
"description": "",
"example_value": "https://www.instagram.com/reels/CtB6vWMMHFD/"
}
},
"required": [
"url"
],
"optional": [
"url"
]
}
},
]
resp = client.chat.completions.create(
model='llama3-8b-instruct',
tools = tools,
messages=messages,
seed=42)
tool_calls = resp.choices[0].message.tool_calls
print(f'query: {query}')
print(f'tool_calls: {tool_calls}')
# 流式
stream_resp = client.chat.completions.create(
model='llama3-8b-instruct',
messages=messages,
tools=tools,
stream=True,
seed=42)
print(f'query: {query}')
print('response: ', end='')
for chunk in stream_resp:
print(chunk.choices[0].delta.content, end='', flush=True)
print()
"""
query: What's the weather like in Boston today?
tool_calls: {'id': 'toolcall-e4c637435e754cf9b2034c3e6861a4ad', 'function': {'arguments': ' {"url": "https://api.weatherapi.com/v1/current.json?key=YOUR_API_KEY&q=Boston"}', 'name': 'url_for_newapi'}, 'type': 'function'}
query: What's the weather like in Boston today?
response: Thought: I need to find the weather information for Boston today. I can use the 'newapi' tool to get the weather forecast.
Action: url_for_newapi
Action Input: {"url": "https://api.weatherapi.com/v1/current.json?key=YOUR_API_KEY&q=Boston"}
"""
```
假设调用返回的结果为`The weather in Boston today is 32°F (0°C), with clear skies`, 我们将结果在role tool字段填入message传入
```shell
curl -X POST http://localhost:8000/v1/chat/completions \

View File

@ -124,6 +124,11 @@
- `--train_dataset_mix_ratio`: 默认为`0.`. 该参数定义了如何进行数据集打混训练. 指定该参数时, 会混合训练集的`train_dataset_mix_ratio`倍数的`train_dataset_mix_ds`指定的通用知识数据集. 该参数已废弃, 请使用`--dataset`进行数据集混合.
- `--train_dataset_mix_ds`: 默认为`['ms-bench']`. 用于防止知识遗忘的通用知识数据集. 该参数已废弃, 请使用`--dataset`进行数据集混合.
- `--use_loss_scale`: 默认为`False`. 生效时会将Agent的部分字段(Action/Action Input部分)的loss权重加强以强化CoT, 对普通SFT场景没有任何效果.
- `--loss_scale_config_path` 选项指定自定义的 loss_scale 配置,适用于在启用 use_loss_scale 时,例如在 Agent 训练中放大 Action 和其他关键 ReAct 字段的损失权重。
- 在配置文件中，您可以使用字典格式来设置 loss_scale。每个键代表一个特定字段名，其关联的值设定了该字段及其后续内容的损失缩放倍数。例如，通过设定 `"Observation:": [2, 0]`，当response包含 `xxxx Observation:error` 时，`Observation:` 字段loss将增加到两倍，`error` 部分的loss则不计入。除了字面匹配，配置也支持正则表达式规则，以实现更灵活的匹配，如模式 `'<.*?>':[2.0]` 将针对所有尖括号括起来的部分损失增加到两倍。字段匹配与正则匹配所对应的损失缩放倍数分别由长度为2和1的列表表示。
- 同时支持匹配query对整段response设置loss_scale, 这在处理像[Agent-FLAN](https://arxiv.org/abs/2403.12881)论文中描述的固定多轮对话查询时极其有用。如果query中包含了预定义键的任一项，相应的响应将采用关联的 loss_scale 值。你可以参考`swift/llm/agent/agentflan.json`。
- 默认情况下，我们为 Action:, Action Input:, Thought:, Final Answer:, 和 Observation: 等字段预设了损失缩放值。我们为[alpha-umi](https://arxiv.org/pdf/2401.07324)和[Agent-FLAN](https://arxiv.org/abs/2403.12881)也提供了默认配置，你可以将`--loss_scale_config_path`分别设置为`alpha-umi`和`agent-flan`来使用。默认的配置文件位于`swift/llm/agent`下。
- 匹配规则的应用优先级从高到低为：query字段 > response特定字段 > 正则表达式匹配规则。
- `--custom_register_path`: 默认为`None`. 传入`.py`文件, 用于注册模板、模型和数据集.
- `--custom_dataset_info`: 默认为`None`, 传入外置dataset_info.json的路径、json字符串或者dict. 用于拓展数据集. 格式参考: https://github.com/modelscope/swift/blob/main/swift/llm/data/dataset_info.json
- `--device_map_config_path`: 从本地文件中手动配置模型的device_map, 默认为None

View File

@ -251,6 +251,73 @@ result
{"model":"llama3-8b-instruct","choices":[[{"index":0,"message":{"role":"assistant","content":"Question: What's the weather like in Boston today?\n\nThought: I need to get the current weather in Boston to answer this question.\n\nAction: get_current_weather\n\nAction Input: {'location': 'Boston, MA', 'unit': 'fahrenheit'}\n\nObservation:","tool_calls":{"id":"toolcall-f534d907ae254f2ab96e06c25179ddf9","function":{"arguments":" {'location': 'Boston, MA', 'unit': 'fahrenheit'}\n\n","name":"get_current_weather"},"type":"function"}},"finish_reason":"stop"}]],"usage":{"prompt_tokens":262,"completion_tokens":54,"total_tokens":316},"id":"chatcmpl-8630e8d675c941c0aca958a37633a3c9","object":"chat.completion","created":1717590756}
```
You can also test with OpenAI SDK, for example
```python
from openai import OpenAI
client = OpenAI(
api_key='EMPTY',
base_url='http://localhost:8000/v1',
)
query = "What's the weather like in Boston today?"
messages = [{
'role': 'user',
'content': query
}]
tools = [
{
"name": "url_for_newapi",
"description": "This is the subfunction for tool \"newapi\", you can use this tool.The description of this function is: \"url_for_newapi\"",
"parameters": {
"type": "object",
"properties": {
"url": {
"type": "string",
"description": "",
"example_value": "https://www.instagram.com/reels/CtB6vWMMHFD/"
}
},
"required": [
"url"
],
"optional": [
"url"
]
}
},
]
resp = client.chat.completions.create(
model='llama3-8b-instruct',
tools = tools,
messages=messages,
seed=42)
tool_calls = resp.choices[0].message.tool_calls
print(f'query: {query}')
print(f'tool_calls: {tool_calls}')
# stream
stream_resp = client.chat.completions.create(
model='llama3-8b-instruct',
messages=messages,
tools=tools,
stream=True,
seed=42)
print(f'query: {query}')
print('response: ', end='')
for chunk in stream_resp:
print(chunk.choices[0].delta.content, end='', flush=True)
print()
"""
query: What's the weather like in Boston today?
tool_calls: {'id': 'toolcall-e4c637435e754cf9b2034c3e6861a4ad', 'function': {'arguments': ' {"url": "https://api.weatherapi.com/v1/current.json?key=YOUR_API_KEY&q=Boston"}', 'name': 'url_for_newapi'}, 'type': 'function'}
query: What's the weather like in Boston today?
response: Thought: I need to find the weather information for Boston today. I can use the 'newapi' tool to get the weather forecast.
Action: url_for_newapi
Action Input: {"url": "https://api.weatherapi.com/v1/current.json?key=YOUR_API_KEY&q=Boston"}
"""
```
In the tool_calls of the returned results, you can obtain the information about the called function and its parameters.
Assuming the returned result is `The weather in Boston today is 32°F (0°C), with clear skies`, we fill this result into a message with role `tool` and pass it in the `messages` field.

View File

@ -125,6 +125,11 @@
- `--train_dataset_mix_ratio`: Default is `0.`. This parameter defines how to mix datasets for training. When this parameter is specified, it will mix the training dataset with a multiple of `train_dataset_mix_ratio` of the general knowledge dataset specified by `train_dataset_mix_ds`. This parameter has been deprecated, please use `--dataset {dataset_name}#{dataset_sample}` to mix datasets.
- `--train_dataset_mix_ds`: Default is `['ms-bench']`. Used for preventing knowledge forgetting, this is the general knowledge dataset. This parameter has been deprecated, please use `--dataset {dataset_name}#{dataset_sample}` to mix datasets.
- `--use_loss_scale`: Default is `False`. When taking effect, strengthens loss weight of some Agent fields (Action/Action Input part) to enhance CoT, has no effect in regular SFT scenarios.
- `--loss_scale_config_path`: Default is `DEFAULT`. Specifies a custom loss_scale configuration. It takes effect when `use_loss_scale` is enabled, for example to amplify the loss weights of Action and other crucial ReAct fields in Agent training.
- In the configuration file, you can set the loss_scale using a dictionary format. Each key represents a specific field name, and its associated value specifies the loss scaling factor for that field and its subsequent content. For instance, setting `"Observation:": [2, 0]` means that when the response contains `xxxx Observation:error`, the loss for the `Observation:` field will be doubled, while the loss for the `error` portion will not be counted. Besides literal matching, the configuration also supports regular expression rules for more flexible matching; for example, the pattern `'<.*?>':[2.0]` doubles the loss for any content enclosed in angle brackets. The loss scaling factors for field matching and regex matching are respectively indicated by lists of length 2 and 1.
- There is also support for setting loss_scale for the entire response based on matching queries, which is extremely useful when dealing with the fixed multi-turn dialogue queries described in the [Agent-FLAN paper](https://arxiv.org/abs/2403.12881). If the query includes any of the predefined keys, the corresponding response will use the associated loss_scale value. Refer to `swift/llm/agent/agentflan.json` for an example.
- By default, we have preset loss scaling values for fields such as Action:, Action Input:, Thought:, Final Answer:, and Observation:. We also provide default configurations for [alpha-umi](https://arxiv.org/pdf/2401.07324) and [Agent-FLAN](https://arxiv.org/abs/2403.12881), which you can use by setting `--loss_scale_config_path` to `alpha-umi` and `agent-flan` respectively. The default configuration files are located under `swift/llm/agent`.
- The application priority of matching rules is as follows, from highest to lowest: query fields > specific response fields > regular expression matching rules.
- `--custom_register_path`: Default is `None`. Pass in a `.py` file used to register templates, models, and datasets.
- `--custom_dataset_info`: Default is `None`. Pass in the path to an external `dataset_info.json`, a JSON string, or a dictionary. Used to register custom datasets. The format example: https://github.com/modelscope/swift/blob/main/swift/llm/data/dataset_info.json
- `--device_map_config_path`: Manually configure the model's device map from a local file, defaults to None.

View File

@ -0,0 +1,22 @@
{
"response":{
"Name:": [1.0, 3.0],
"Action:": [1.0, 3.0],
"ACTION:": [1.0,3.0],
"Tool:": [1.0, 3.0],
"Command": [1.0, 3.0],
"Arguments:": [1.0, 3.0],
"action input": [1.0, 3.0],
"ACTION_INPUT:":[1.0, 3.0],
"Action Input:": [1.0, 3.0],
"Thought:": [1.0, 1.0],
"Final Answer:": [1.0, 1.0],
"Observation:": [2.0, 0.0]
},
"query":{
"What is the tool you want to use": [3.0],
"What are the required parameter names": [3.0],
"What is the value of": [3.0],
"What are the required parameter names for this tool": [3.0]
}
}

View File

@ -0,0 +1,8 @@
{
"Action:": [2.0, 2.0],
"Action Input:": [2.0, 2.0],
"Thought:": [1.0, 1.0],
"Final Answer:": [1.0, 1.0],
"Observation:": [2.0, 0.0],
"Next:": [2.0, 2.0]
}

View File

@ -0,0 +1,7 @@
{
"Action:": [2.0, 2.0],
"Action Input:": [2.0, 2.0],
"Thought:": [1.0, 1.0],
"Final Answer:": [1.0, 1.0],
"Observation:": [2.0, 0.0]
}

View File

@ -1,5 +1,5 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from typing import List, Optional, Tuple, Union
from typing import Dict, List, Optional, Tuple, Union
from swift.utils import get_logger
from swift.utils.utils import split_str_parts_by
@ -65,7 +65,11 @@ use function Finish->give_up_and_restart.
Specifically, you have access to the following APIs: {tool_list}'''
def calculate_loss_scale(response: str, use_loss_scale=False) -> Tuple[List[str], List[float]]:
def calculate_loss_scale(query: str,
response: str,
use_loss_scale=False,
response_loss_scale_map: Optional[dict[str, list]] = None,
query_loss_scale_map: Optional[dict[str, list]] = None) -> Tuple[List[str], List[float]]:
"""Calculate the loss scale by splitting the agent response.
This algorithm comes from paper: https://arxiv.org/pdf/2309.00986.pdf
@ -90,41 +94,35 @@ def calculate_loss_scale(response: str, use_loss_scale=False) -> Tuple[List[str]
Returns:
A tuple of agent response parts and their weights.
"""
if 'Action:' in response and 'Observation:' in response and use_loss_scale:
agent_keyword = ['Action:', 'Action Input:', 'Thought:', 'Final Answer:', 'Observation:']
agent_parts = split_str_parts_by(response, agent_keyword)
if use_loss_scale:
# query loss scale map
if query_loss_scale_map is not None:
for key in query_loss_scale_map.keys():
if key in query:
if isinstance(query_loss_scale_map[key], (float, int)):
query_loss_scale_map[key] = [query_loss_scale_map[key]]
loss_scale_value = query_loss_scale_map[key][0]
return [response], [float(loss_scale_value)]
delimiters = list(k for k in response_loss_scale_map.keys() if len(response_loss_scale_map[k]) == 2)
agent_parts = split_str_parts_by(response, delimiters)
regex_delimiters = {k: v for k, v in response_loss_scale_map.items() if len(v) == 1}
if len(regex_delimiters):
agent_parts = split_parts_by_regex(agent_parts, regex_delimiters)
weights = []
agent_content = []
for c in agent_parts:
if c['key'] in ('Action:', 'Action Input:'):
weights += [2.0]
weights += [2.0]
elif c['key'] in ('Thought:', 'Final Answer:', ''):
weights += [1.0]
weights += [1.0]
elif c['key'] in ('Observation:', ):
weights += [2.0]
weights += [0.0]
agent_content.append(c['key'])
agent_content.append(c['content'])
return agent_content, weights
elif ('Action:' in response or 'Next:' in response) and use_loss_scale: # alpha-umi
agent_keyword = ['Next:', 'Action:', 'Action Input:']
agent_parts = split_str_parts_by(response, agent_keyword)
weights = []
agent_content = []
for c in agent_parts:
if c['key'] in ('Action:', 'Action Input:', 'Next:'):
weights += [2.0]
weights += [2.0]
elif c['key'] in ('Thought:', 'Final Answer:', ''):
weights += [1.0]
weights += [1.0]
elif c['key'] in ('Observation:', ):
weights += [2.0]
weights += [0.0]
agent_content.append(c['key'])
agent_content.append(c['content'])
if isinstance(c['key'], (float, int)):
weights += [c['key']]
agent_content.append(c['content'])
else:
if c['key'] in response_loss_scale_map:
weights += [response_loss_scale_map[c['key']][0]]
weights += [response_loss_scale_map[c['key']][1]]
agent_content.append(c['key'])
agent_content.append(c['content'])
else:
weights += [1.0]
agent_content.append(c['content'])
return agent_content, weights
else:
return [response], [1.0]
@ -150,6 +148,31 @@ def split_action_action_input(response: str) -> Tuple[Optional[str], Optional[st
return action, action_input
def split_parts_by_regex(text_list: list, regex_delimiters: Dict[str, List[float]]) -> list:
    """Split the un-keyed parts of ``text_list`` by regular expressions.

    Every item of ``text_list`` is a dict with a ``'key'`` and a ``'content'``
    entry. Items whose key is the empty string are scanned with each pattern
    of ``regex_delimiters``; every matched span becomes its own item whose
    ``'key'`` is the numeric loss scale of the pattern (``scale[0]``), and the
    text between matches is kept as items with an empty key. Items with a
    non-empty key are left untouched.

    Args:
        text_list: The list of parts, modified in place.
        regex_delimiters: Maps a regex pattern to a one-element list holding
            the loss scale applied to the matched text.

    Returns:
        The same ``text_list`` object, so the result can be either used in
        place or re-assigned by the caller.
    """
    import re
    compiled_patterns = [(re.compile(pattern), scale) for pattern, scale in regex_delimiters.items()]
    # Walk backwards so that growing the list at position i does not shift
    # the indices that are still to be visited.
    for i in range(len(text_list) - 1, -1, -1):
        item = text_list[i]
        if item.get('key') != '':
            continue
        res_text = item['content']

        # Collect the matches of all patterns first and order them by position;
        # a per-pattern cursor would move backwards when a later pattern
        # matches earlier in the text. Ties are resolved by pattern order.
        candidates = []
        for order, (pattern, scale) in enumerate(compiled_patterns):
            for match in pattern.finditer(res_text):
                if match.end() > match.start():  # ignore empty matches
                    candidates.append((match.start(), order, match.end(), scale[0]))
        candidates.sort()

        segments = []
        last_idx = 0
        for start, _, end, weight in candidates:
            if start < last_idx:
                # Overlaps a span that has already been emitted.
                continue
            if start > last_idx:
                segments.append({'key': '', 'content': res_text[last_idx:start]})
            segments.append({'key': weight, 'content': res_text[start:end]})
            last_idx = end

        if not segments:
            # Nothing matched: keep the original item as is.
            continue
        if last_idx < len(res_text):
            # The unmatched tail belongs after the matches, keeping text order.
            segments.append({'key': '', 'content': res_text[last_idx:]})
        text_list[i:i + 1] = segments
    return text_list
def get_tools_prompt(TOOLS: list[dict[str, Union[str, dict]]], prompt_format: str = 'react_en') -> Optional[str]:
tool_descs = []
tool_names = []

View File

@ -162,6 +162,13 @@
},
"tags": ["chat", "agent", "multi-round", "🔥"]
},
"msagent-pro": {
"dataset_id": "iic/MSAgent-Pro",
"conversations": {
"error_strategy": "delete"
},
"tags": ["chat", "agent", "multi-round", "🔥"]
},
"codefuse-python-en": {
"dataset_id": "codefuse-ai/CodeExercise-Python-27k",
"conversations": {

View File

@ -197,6 +197,12 @@ def llm_sft(args: SftArguments) -> Dict[str, Union[str, Any]]:
if use_model:
template_kwargs['model'] = model
template_kwargs['use_loss_scale'] = args.use_loss_scale
if args.loss_scale_config_path is not None:
cwd = os.getcwd()
config_path = args.loss_scale_config_path if os.path.isabs(args.loss_scale_config_path) else os.path.join(
cwd, args.loss_scale_config_path)
with open(config_path, 'r') as json_file:
template_kwargs['loss_scale_map'] = json.load(json_file)
template_kwargs['tools_prompt'] = args.tools_prompt
if args.sequence_parallel_size and args.sequence_parallel_size > 1:
template_kwargs['sequence_parallel_size'] = args.sequence_parallel_size

View File

@ -461,6 +461,7 @@ class SftArguments(ArgumentsBase):
dataset_seed: int = 42
dataset_test_ratio: float = 0.01
use_loss_scale: bool = False # for agent
loss_scale_config_path: str = 'DEFAULT'
system: Optional[str] = None
tools_prompt: Literal['react_en', 'react_zh', 'toolbench'] = 'react_en'
max_length: int = 2048 # -1: no limit
@ -753,7 +754,16 @@ class SftArguments(ArgumentsBase):
if self.deepspeed == ds_name:
self.deepspeed = os.path.join(ds_config_folder, ds_config)
break
if self.loss_scale_config_path:
if self.loss_scale_config_path == 'DEFAULT':
self.loss_scale_config_path = os.path.abspath(
os.path.join(__file__, '..', '..', 'agent', 'default_loss_scale_config.json'))
elif self.loss_scale_config_path == 'alpha-umi': # https://arxiv.org/pdf/2401.07324
self.loss_scale_config_path = os.path.abspath(
os.path.join(__file__, '..', '..', 'agent', 'alpha_umi_loss_scale_config.json'))
elif self.loss_scale_config_path == 'agent-flan': # https://arxiv.org/abs/2403.12881
self.loss_scale_config_path = os.path.abspath(
os.path.join(__file__, '..', '..', 'agent', 'agentflan.json'))
self.handle_path()
self._handle_dataset_sample()
self._register_self_cognition()

View File

@ -72,6 +72,8 @@ class DatasetName:
damo_agent_zh = 'damo-agent-zh'
damo_agent_zh_mini = 'damo-agent-zh-mini'
agent_instruct_all_en = 'agent-instruct-all-en'
msagent_pro = 'msagent-pro'
toolbench = 'toolbench'
# coding
code_alpaca_en = 'code-alpaca-en'
@ -1004,6 +1006,39 @@ register_dataset(
tags=['chat', 'agent', 'multi-round', 'role-play', 'multi-agent'])
def _preprocess_toolbench(dataset: HfDataset) -> HfDataset:
    """Reshape each row's ``conversations`` list into system/history/query/response fields.

    The first turn supplies the system text, the last two turns supply the
    query and the response, and the turns in between are grouped pairwise
    into ``history`` (values) and ``history_roles`` (``from`` fields).
    """

    def reorganize_row(row):
        turns = row['conversations']
        system_text = turns[0]['value']
        # Start index of every (turn, next turn) pair that precedes the final two turns.
        pair_starts = range(1, len(turns) - 2, 2)
        past_values = [(turns[k]['value'], turns[k + 1]['value']) for k in pair_starts]
        past_roles = [(turns[k]['from'], turns[k + 1]['from']) for k in pair_starts]
        query_turn = turns[-2]
        response_turn = turns[-1]
        return {
            'system': system_text,
            'history': past_values,
            'history_roles': past_roles,
            'query': query_turn['value'],
            'query_role': query_turn['from'],
            'response': response_turn['value']
        }

    return dataset.map(reorganize_row)
# Register the ToolBench agent dataset under DatasetName.toolbench, identified by
# 'swift/ToolBench', with _preprocess_toolbench and get_dataset_from_repo passed to
# the registry.
# NOTE(review): the third positional argument is None; the meaning of each positional
# slot is defined by register_dataset, which is not visible here -- confirm.
# NOTE(review): remove_useless_columns=False presumably keeps the extra columns the
# preprocessor emits (history_roles, query_role) -- verify against register_dataset.
register_dataset(
DatasetName.toolbench,
'swift/ToolBench',
None,
_preprocess_toolbench,
get_dataset_from_repo,
remove_useless_columns=False,
tags=['chat', 'agent', 'multi-round'])
def _preprocess_hc3(dataset: HfDataset) -> HfDataset:
prompt = """Classification Task: Are the following responses from a human or from ChatGPT?
Question: {question}

View File

@ -224,6 +224,15 @@ class Template:
self.truncation_strategy = truncation_strategy
self.model = kwargs.get('model', None)
self.use_loss_scale = kwargs.get('use_loss_scale', False)
self.response_loss_scale_map = kwargs.get('loss_scale_map', None)
self.query_loss_scale_map = None
if self.response_loss_scale_map is not None:
if 'query' in self.response_loss_scale_map and isinstance(self.response_loss_scale_map['query'], dict):
self.query_loss_scale_map = self.response_loss_scale_map['query']
if 'response' in self.response_loss_scale_map and isinstance(self.response_loss_scale_map['response'],
dict):
self.response_loss_scale_map = self.response_loss_scale_map['response']
self.sequence_parallel_size = kwargs.get('sequence_parallel_size', 1)
for key in ['prefix', 'prompt', 'chat_sep', 'suffix', 'prefix_has_system']:
@ -277,15 +286,14 @@ class Template:
return inputs, tokenizer_kwargs
def _concat_context_list(
self,
context_list: List[Context],
res_context_list: List[Context], # inplace
loss_scale_list: List[float], # inplace
system: Optional[str] = None,
query: Optional[str] = None,
response: Optional[str] = None,
round0: Optional[int] = None,
) -> None:
self,
context_list: List[Context],
res_context_list: List[Context], # inplace
loss_scale_list: List[float], # inplace
system: Optional[str] = None,
query: Optional[str] = None,
response: Optional[str] = None,
round0: Optional[int] = None) -> None:
# concat context list and replace placeholder
round1 = None
if round0 is not None:
@ -295,7 +303,9 @@ class Template:
if isinstance(context, str):
if '{{RESPONSE}}' == context:
assert response is not None
content_part, weight_part = calculate_loss_scale(response, self.use_loss_scale)
content_part, weight_part = calculate_loss_scale(query, response, self.use_loss_scale,
self.response_loss_scale_map,
self.query_loss_scale_map)
res_context_list.extend(content_part)
loss_scale_list.extend(weight_part)
continue
@ -408,7 +418,6 @@ class Template:
if q or r:
self._concat_context_list(
context_list, res_context_list, loss_scale_list, query=q, response=r, round0=i)
res_context_list, loss_scale_list = self._simplify_context_list(res_context_list, loss_scale_list)
input_ids, labels, loss_scale, tokenizer_kwargs = self._encode_context_list(res_context_list, loss_scale_list)