Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 27 additions & 4 deletions client_tools/client_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,8 +231,15 @@ def generate_client_class(class_name: str,
"""Generate client wrapper class code."""

def build_imports() -> Tuple[List[str], str]:
# Include both method signatures and __init__ signature for import detection
signatures = [sig for _, sig in methods]
# Include both method signatures and __init__ signature for import detection.
# Use the timeout-injected signature for ``encode`` so the ``Optional``
# import introduced by ``_inject_timeout`` is detected.
signatures = []
for name, sig in methods:
if name == 'encode' and class_name in ('Dataset', 'LazyDataset'):
signatures.append(_inject_timeout(sig))
else:
signatures.append(sig)
if init_signature:
signatures.append(init_signature)

Expand Down Expand Up @@ -261,15 +268,31 @@ def build_imports() -> Tuple[List[str], str]:
lines.append('')
return lines, inheritance

def _inject_timeout(signature: str) -> str:
"""Insert `timeout: Optional[int] = 600` before any **kwargs in the signature."""
if 'timeout' in signature:
return signature
if ', **' in signature:
pre, post = signature.rsplit(', **', 1)
return f'{pre}, timeout: Optional[int] = 600, **{post}'
if signature.startswith('**'):
return f'timeout: Optional[int] = 600, {signature}'
if signature:
return f'{signature}, timeout: Optional[int] = 600'
return 'timeout: Optional[int] = 600'

def build_method(name: str, signature: str) -> str:
param_names = parse_params_from_signature(signature)
kwargs_dict = '{' + ', '.join(f"'{p}': {p}" for p in param_names) + '}' if param_names else '{}'
sig_part = f', {signature}' if signature else ''
wants_timeout = name == 'encode' and class_name in ('Dataset', 'LazyDataset')
effective_sig = _inject_timeout(signature) if wants_timeout else signature
sig_part = f', {effective_sig}' if effective_sig else ''
if 'kwargs' in sig_part:
extra_args = '\n **kwargs'
else:
extra_args = ''
ret = 'self' if name == '__iter__' else 'response.json()["result"]'
timeout_kwarg = ',\n timeout=timeout' if wants_timeout else ''

code = f'''
def {name}(self{sig_part}):
Expand All @@ -279,7 +302,7 @@ def {name}(self{sig_part}):
'processor_id': self.processor_id,
'function': '{name}',
**{kwargs_dict},{extra_args}
}}
}}{timeout_kwarg}
)
response.raise_for_status()
return {ret}
Expand Down
1 change: 1 addition & 0 deletions docs/source_en/Components/Dataset/Dataset.md
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,7 @@ dataset.encode()
> 1. Dataset's `map`, `encode`, `filter`, and other methods all use the `map` method of `datasets`, so you can use the corresponding parameters in the kwargs of the corresponding methods
> 2. The `load_from_cache_file` parameter defaults to False, because setting it to True can cause training to keep using a stale cache after the dataset has changed. If your dataset is large and updated infrequently, you can set it to True directly
> 3. `encode` does not need a `DatasetMeta` to be specified, because after preprocessing all datasets share the same format
> 4. `encode` tokenizes with a single process by default. For large datasets, enable multi-process parallelism via `num_proc`, e.g. `dataset.encode(num_proc=8)`

6. Getting data

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,11 @@ dataset.map('SelfCognitionProcessor',

# Encode dataset into tokens usable by the model
dataset.encode(batched=True)
# For large datasets, use num_proc to enable multi-process parallelism:
# dataset.encode(batched=True, num_proc=8)
# When using twinkle_client.dataset, encode calls the remote server over HTTP
# with a default 600s timeout; raise it via the timeout argument if needed:
# dataset.encode(batched=True, num_proc=8, timeout=3600)

# Create DataLoader
dataloader = DataLoader(dataset=dataset, batch_size=4)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,11 @@ dataset.map('SelfCognitionProcessor',

# 编码数据集为模型可用的 token
dataset.encode(batched=True)
# 数据量大时可用 num_proc 多进程加速:
# dataset.encode(batched=True, num_proc=8)
# 使用 twinkle_client.dataset 时,encode 是通过 HTTP 调用远端服务,
# 默认 600 秒超时,可用 timeout 参数按需调大:
# dataset.encode(batched=True, num_proc=8, timeout=3600)

# 创建 DataLoader
dataloader = DataLoader(dataset=dataset, batch_size=4)
Expand Down
1 change: 1 addition & 0 deletions docs/source_zh/组件/数据集/Dataset.md
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,7 @@ dataset.encode()
> 1. Dataset 的 `map`、`encode`、`filter` 等方法均使用 `datasets` 的 `map` 方式进行,因此在对应方法的 kwargs 中均可以使用对应的参数
> 2. `load_from_cache_file` 参数默认为 False,因为该参数设置为 True 时会引发一些数据集改变但训练仍然使用缓存的头疼问题。如果你的数据集较大而且更新不频繁,可以直接置为 True
> 3. encode 不需要指定 `DatasetMeta`,因为预处理过后所有数据集格式都是相同的
> 4. `encode` 默认使用单进程分词。数据量较大时可通过 `num_proc` 开启多进程并行加速,例如 `dataset.encode(num_proc=8)`

6. 获取数据

Expand Down
34 changes: 31 additions & 3 deletions src/twinkle_client/dataset/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
# 2. Run: python client_tools/client_generator.py
# ============================================================================

from typing import Any, Callable, Dict, Type, Union
from typing import Any, Callable, Dict, Optional, Type, Union
from twinkle_client.http import http_post
from twinkle.dataset import Dataset
from twinkle.dataset import DatasetMeta
Expand Down Expand Up @@ -50,15 +50,16 @@ def set_template(self, template_func: Union[Template, Type[Template], str], **kw
return response.json()["result"]


def encode(self, add_generation_prompt: bool = False, **kwargs):
def encode(self, add_generation_prompt: bool = False, timeout: Optional[int] = 600, **kwargs):
response = http_post(
url=f'{self.server_url}/call',
json_data={
'processor_id': self.processor_id,
'function': 'encode',
**{'add_generation_prompt': add_generation_prompt},
**kwargs
}
},
timeout=timeout
)
response.raise_for_status()
return response.json()["result"]
Expand Down Expand Up @@ -146,6 +147,33 @@ def mix_dataset(self, interleave = True):
return response.json()["result"]


def save_as(self, output_path: str, format: Optional[str] = None, batch_size: int = 1000, mode: str = 'immediate', **kwargs):
    """POST a ``save_as`` call for this processor to ``{server_url}/call``.

    The named arguments and any extra ``kwargs`` are sent in the JSON body
    next to ``processor_id`` and ``function``; keys in ``kwargs`` override
    earlier keys of the same name.

    Returns:
        The ``"result"`` entry of the JSON response.

    Raises:
        Whatever ``response.raise_for_status()`` raises for a failed request.
    """
    endpoint = f'{self.server_url}/call'
    payload = {
        'processor_id': self.processor_id,
        'function': 'save_as',
        'output_path': output_path,
        'format': format,
        'batch_size': batch_size,
        'mode': mode,
    }
    # Applied last so caller-supplied extras take precedence, as with ``**kwargs``
    # unpacked at the end of a dict literal.
    payload.update(kwargs)
    response = http_post(url=endpoint, json_data=payload)
    response.raise_for_status()
    body = response.json()
    return body["result"]


def flush_save(self):
    """POST a ``flush_save`` call for this processor to ``{server_url}/call``.

    Returns:
        The ``"result"`` entry of the JSON response.

    Raises:
        Whatever ``response.raise_for_status()`` raises for a failed request.
    """
    endpoint = f'{self.server_url}/call'
    # No method arguments: the body carries only the routing fields.
    payload = {
        'processor_id': self.processor_id,
        'function': 'flush_save',
    }
    response = http_post(url=endpoint, json_data=payload)
    response.raise_for_status()
    body = response.json()
    return body["result"]


def __getitem__(self, idx):
response = http_post(
url=f'{self.server_url}/call',
Expand Down
7 changes: 4 additions & 3 deletions src/twinkle_client/dataset/lazy_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
# 2. Run: python client_tools/client_generator.py
# ============================================================================

from typing import Any, Callable, Dict, Type, Union
from typing import Any, Callable, Dict, Optional, Type, Union
from twinkle_client.http import http_post
from twinkle.dataset import Dataset
from twinkle.dataset import DatasetMeta
Expand Down Expand Up @@ -91,15 +91,16 @@ def mix_dataset(self, interleave = True):
return response.json()["result"]


def encode(self, add_generation_prompt: bool = False, **kwargs):
def encode(self, add_generation_prompt: bool = False, timeout: Optional[int] = 600, **kwargs):
response = http_post(
url=f'{self.server_url}/call',
json_data={
'processor_id': self.processor_id,
'function': 'encode',
**{'add_generation_prompt': add_generation_prompt},
**kwargs
}
},
timeout=timeout
)
response.raise_for_status()
return response.json()["result"]
Expand Down
4 changes: 2 additions & 2 deletions src/twinkle_client/http/http_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ def http_post(
json_data: Optional[Dict[str, Any]] = {},
data: Optional[Any] = {},
additional_headers: Optional[Dict[str, str]] = {},
timeout: int = 600,
timeout: Optional[int] = 600,
) -> requests.Response:
"""
Send HTTP POST request with required headers.
Expand All @@ -130,7 +130,7 @@ def http_post(
json_data: JSON data to send in request body
data: Form data or raw data to send in request body
additional_headers: Additional headers to include
timeout: Request timeout in seconds
timeout: Request timeout in seconds; None disables the timeout.

Returns:
requests.Response object
Expand Down
Loading