allow generate_data logger parameter to overwrite locally defined loggers #449

Closed
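For context on the change itself: every module in this diff switches from a module-level `logging.getLogger(__name__)` to `logging.getLogger()`, i.e. the root logger, so log output is governed by whatever configuration the caller applies to the root logger, and the `logger` parameter on `generate_data` is dropped. A minimal caller-side sketch (not part of this PR; the `basicConfig` values are illustrative only):

```python
import logging

# Hypothetical caller setup: configure the root logger once.
# Every module that calls `logging.getLogger()` then emits through
# these handlers and honors this level.
logging.basicConfig(
    level=logging.DEBUG,
    format="%(asctime)s %(levelname)s %(name)s: %(message)s",
)

logger = logging.getLogger()  # root logger, shared configuration
logger.info("visible: root-logger handlers and level apply directly")
```

With `logging.getLogger(__name__)` each module owns a named child logger; those still propagate to the root by default, but the diff removes that indirection along with the now-redundant `logger` parameter.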
2 changes: 1 addition & 1 deletion src/instructlab/sdg/blocks/block.py
@@ -13,7 +13,7 @@
# Local
from ..registry import BlockRegistry

-logger = logging.getLogger(__name__)
+logger = logging.getLogger()


# This is part of the public API.
2 changes: 1 addition & 1 deletion src/instructlab/sdg/blocks/filterblock.py
@@ -11,7 +11,7 @@
from ..registry import BlockRegistry
from .block import Block

-logger = logging.getLogger(__name__)
+logger = logging.getLogger()


# This is part of the public API.
2 changes: 1 addition & 1 deletion src/instructlab/sdg/blocks/iterblock.py
@@ -11,7 +11,7 @@
from ..registry import BlockRegistry
from .block import Block

-logger = logging.getLogger(__name__)
+logger = logging.getLogger()


# This is part of the public API.
2 changes: 1 addition & 1 deletion src/instructlab/sdg/blocks/llmblock.py
@@ -17,7 +17,7 @@
from ..registry import BlockRegistry, PromptRegistry
from .block import Block, BlockConfigParserError

-logger = logging.getLogger(__name__)
+logger = logging.getLogger()

DEFAULT_MAX_NUM_TOKENS = 4096

2 changes: 1 addition & 1 deletion src/instructlab/sdg/blocks/utilblocks.py
@@ -13,7 +13,7 @@
from ..registry import BlockRegistry
from .block import Block

-logger = logging.getLogger(__name__)
+logger = logging.getLogger()


# This is part of the public API.
2 changes: 1 addition & 1 deletion src/instructlab/sdg/checkpointing.py
@@ -11,7 +11,7 @@
# First Party
from instructlab.sdg.utils import pandas

-logger = logging.getLogger(__name__)
+logger = logging.getLogger()


class Checkpointer:
24 changes: 12 additions & 12 deletions src/instructlab/sdg/datamixing.py
@@ -22,7 +22,7 @@
# when |knowledge| << |skills|
MIN_UPSAMPLE_THRESHOLD = 0.03
ALLOWED_COLS = ["id", "messages", "metadata"]
-logger = logging.getLogger(__name__)
+LOGGER = logging.getLogger()


class DatasetListing(TypedDict):
@@ -40,7 +40,7 @@ def _adjust_train_sample_size(ds: Dataset, num_samples: int):
Return a dataset with num_samples random samples selected from the
original dataset.
"""
-logger.info(f"Rebalancing dataset to have {num_samples} samples ...")
+LOGGER.info(f"Rebalancing dataset to have {num_samples} samples ...")
df = ds.to_pandas()
df = df.sample(n=num_samples, random_state=42, replace=True)
return pandas.dataset_from_pandas_dataframe(df)
@@ -135,10 +135,10 @@ def _load_ds(self, path):
"""
if not os.path.isabs(path):
path = os.path.join(os.path.dirname(self.recipe_path), path)
-logger.info(f"Loading dataset from {path} ...")
+LOGGER.info(f"Loading dataset from {path} ...")
dataset = load_dataset("json", data_files=path, split="train")
-logger.info(f"Dataset columns: {dataset.column_names}")
-logger.info(f"Dataset loaded with {len(dataset)} samples")
+LOGGER.info(f"Dataset columns: {dataset.column_names}")
+LOGGER.info(f"Dataset loaded with {len(dataset)} samples")
return dataset

def _load_and_sample_datasets(self, num_proc):
@@ -161,7 +161,7 @@ def _create_mixed_dataset(self, num_proc):
concatenating all datasets in this recipe
"""
if not self.dataset_added:
-logger.error("No dataset added to the recipe")
+LOGGER.error("No dataset added to the recipe")

mixed_ds = self._load_and_sample_datasets(num_proc)
mixed_ds = concatenate_datasets(mixed_ds)
@@ -212,7 +212,7 @@ def save_mixed_dataset(self, output_path, num_proc):
"""
mixed_ds = self._create_mixed_dataset(num_proc)
mixed_ds.to_json(output_path, orient="records", lines=True)
-logger.info(f"Mixed Dataset saved to {output_path}")
+LOGGER.info(f"Mixed Dataset saved to {output_path}")


def _unescape(s):
@@ -235,7 +235,7 @@ def _get_question_hack(synth_example):

parts = synth_example["output"].split("?", 1)
if len(parts) != 2:
-logger.warning(f"Failed to split generated q&a: {synth_example['output']}")
+LOGGER.warning(f"Failed to split generated q&a: {synth_example['output']}")
return parts[0].strip() + "?" if len(parts) == 2 else ""


@@ -251,7 +251,7 @@ def _get_response_hack(synth_example):

parts = synth_example["output"].split("?", 1)
if len(parts) != 2:
-logger.warning(f"Failed to split generated q&a: {synth_example['output']}")
+LOGGER.warning(f"Failed to split generated q&a: {synth_example['output']}")
return parts[1].strip() if len(parts) == 2 else parts[0].strip()


@@ -333,7 +333,7 @@ def __pick_documents(rec, p):
selected_docs = [e for e in all_context if e != answer_document]
if len(selected_docs) > 0:
if len(selected_docs) < num_doc_in_context:
-logger.debug(
+LOGGER.debug(
f"Number of unique documents is {len(selected_docs)} which is less than {num_doc_in_context}. Using all the documents in the expanded context."
)
if random.uniform(0, 1) < p:
@@ -352,7 +352,7 @@ def __pick_documents(rec, p):
else selected_docs
)
else:
-logger.warning(
+LOGGER.warning(
"Only 1 unique document found. Disabling expanded context injection, which may lead to poorer knowledge retention results."
)
docs = [answer_document]
@@ -697,7 +697,7 @@ def collect(
if knowledge_to_skills_ratio < MIN_UPSAMPLE_THRESHOLD:
sampling_size = int(self._precomputed_skills_length * 0.03)

-logger.info(
+LOGGER.info(
"\033[93mKnowledge detected to be less than %.2f%% of skills (%.2f%%), upsampling to: %d\033[0m",
MIN_UPSAMPLE_THRESHOLD * 100,
knowledge_to_skills_ratio * 100,
2 changes: 1 addition & 1 deletion src/instructlab/sdg/eval_data.py
@@ -13,7 +13,7 @@
# First Party
from instructlab.sdg.pipeline import EVAL_PIPELINES_PKG, Pipeline

-logger = logging.getLogger(__name__)
+logger = logging.getLogger()


def _extract_options(text: str) -> list[Any]:
23 changes: 11 additions & 12 deletions src/instructlab/sdg/generate_data.py
@@ -34,7 +34,7 @@
read_taxonomy_leaf_nodes,
)

-logger = logging.getLogger(__name__)
+LOGGER = logging.getLogger()

_SYS_PROMPT = "I am a Red Hat® Instruct Model, an AI language model developed by Red Hat and IBM Research based on the granite-3.0-8b-base model. My primary role is to serve as a chat assistant."

@@ -90,7 +90,7 @@ def _gen_train_data(

for output_dataset in machine_instruction_data:
for synth_example in output_dataset:
-logger.debug(synth_example)
+LOGGER.debug(synth_example)
user = _get_question_hack(synth_example)
if len(synth_example.get("context", "")) > 0:
user += "\n" + synth_example["context"]
@@ -223,7 +223,7 @@ def _sdg_init(ctx, pipeline):
config = yaml.safe_load(file)
docling_model_path = config["models"][0]["path"]
except (FileNotFoundError, NotADirectoryError, PermissionError) as e:
-logger.warning(f"unable to read docling models path from config.yaml {e}")
+LOGGER.warning(f"unable to read docling models path from config.yaml {e}")

for d in data_dirs:
pipeline_path = os.path.join(d, "pipelines", pipeline)
@@ -285,7 +285,6 @@ def _mixer_init(
# to be removed: logger
def generate_data(
client: openai.OpenAI,
-logger: logging.Logger = logger, # pylint: disable=redefined-outer-name
system_prompt: Optional[str] = None,
use_legacy_pretraining_format: Optional[bool] = True,
model_family: Optional[str] = None,
@@ -352,7 +351,7 @@ def generate_data(
system_prompt,
)

-logger.debug(f"Generating to: {os.path.join(output_dir, output_file_test)}")
+LOGGER.debug(f"Generating to: {os.path.join(output_dir, output_file_test)}")

model_family = models.get_model_family(model_family, model_name)

@@ -385,7 +384,7 @@
)

if console_output:
-logger.info(
+LOGGER.info(
"Synthesizing new instructions. If you aren't satisfied with the generated instructions, interrupt training (Ctrl-C) and try adjusting your YAML files. Adding more examples may help."
)

@@ -417,17 +416,17 @@ def generate_data(
else:
pipe = freeform_skills_pipe

-logger.debug("Samples: %s", samples)
+LOGGER.debug("Samples: %s", samples)

new_generated_data = pipe.generate(samples, leaf_node_path)
if len(new_generated_data) == 0:
empty_sdg_leaf_nodes.append(leaf_node_path)
-logger.warning("Empty dataset for qna node: %s", leaf_node_path)
+LOGGER.warning("Empty dataset for qna node: %s", leaf_node_path)
continue
generated_data.append(new_generated_data)

-logger.info("Generated %d samples", len(generated_data))
-logger.debug("Generated data: %s", generated_data)
+LOGGER.info("Generated %d samples", len(generated_data))
+LOGGER.debug("Generated data: %s", generated_data)

if is_knowledge:
# generate mmlubench data for the current leaf node
@@ -456,9 +455,9 @@ def generate_data(
mixer.generate()

generate_duration = time.time() - generate_start
-logger.info(f"Generation took {generate_duration:.2f}s")
+LOGGER.info(f"Generation took {generate_duration:.2f}s")
if len(empty_sdg_leaf_nodes) > 0:
-logger.warning(
+LOGGER.warning(
"Leaf nodes with empty sdg output: {}".format(
" ".join(empty_sdg_leaf_nodes)
)
11 changes: 5 additions & 6 deletions src/instructlab/sdg/pipeline.py
@@ -23,7 +23,7 @@
from .blocks.block import Block
from .registry import BlockRegistry

-logger = logging.getLogger(__name__)
+LOGGER = logging.getLogger()


# This is part of the public API.
@@ -140,7 +140,6 @@ def generate(self, dataset, checkpoint_name=None) -> Dataset:
dataset: the input dataset
checkpoint_name: unique subdir name for the checkpoint within checkpoint_dir
"""
-
# The checkpointer allows us to resume from where we left off
# Saving the output of pipe instances along the way
checkpoint_dir = None
@@ -153,12 +152,12 @@ def generate(self, dataset, checkpoint_name=None) -> Dataset:

# If not batching, simply delegate to _generate_single
if not self.ctx.batching_enabled:
-logger.info("Running pipeline single-threaded")
+LOGGER.info("Running pipeline single-threaded")
return self._generate_single(dataset)

# Otherwise, split the dataset into batches and run each batch as a
# future in the thread pool
-logger.info(
+LOGGER.info(
"Running pipeline with multi-threaded batching. Using %s workers for batches of size %s",
self.ctx.batch_num_workers,
self.ctx.batch_size,
@@ -197,7 +196,7 @@ def _generate_single(self, dataset) -> Dataset:
drop_columns = block_prop.get("drop_columns", [])
drop_duplicates_cols = block_prop.get("drop_duplicates", False)
block = block_type(self.ctx, self, block_name, **block_config)
-logger.info("Running block: %s", block_name)
+LOGGER.info("Running block: %s", block_name)
# Execute the block and wrap errors with the block name/type
dataset = block.generate(dataset)
except Exception as err:
@@ -284,7 +283,7 @@ def _parse_pipeline_config_file(pipeline_yaml):
"The pipeline config file format is from a future major version."
)
if major <= _PIPELINE_CONFIG_PARSER_MAJOR and minor > _PIPELINE_CONFIG_PARSER_MINOR:
-logger.warning(
+LOGGER.warning(
"The pipeline config file may have new features that will be ignored."
)

2 changes: 1 addition & 1 deletion src/instructlab/sdg/registry.py
@@ -5,7 +5,7 @@
# Third Party
from jinja2 import Environment, StrictUndefined, Template

-logger = logging.getLogger(__name__)
+logger = logging.getLogger()


class BlockRegistry:
2 changes: 1 addition & 1 deletion src/instructlab/sdg/utils/chunkers.py
@@ -24,7 +24,7 @@
# First Party
from instructlab.sdg.utils.model_formats import is_model_gguf, is_model_safetensors

-logger = logging.getLogger(__name__)
+logger = logging.getLogger()
_DEFAULT_CHUNK_OVERLAP = 100


2 changes: 1 addition & 1 deletion src/instructlab/sdg/utils/model_formats.py
@@ -7,7 +7,7 @@
# Third Party
from gguf.constants import GGUF_MAGIC

-logger = logging.getLogger(__name__)
+logger = logging.getLogger()


def is_model_safetensors(model_path: pathlib.Path) -> bool: