diff --git a/packages/langchain_huggingface/example/langchain_huggingface_example.dart b/packages/langchain_huggingface/example/langchain_huggingface_example.dart
index 21f3e9f2..410dd202 100644
--- a/packages/langchain_huggingface/example/langchain_huggingface_example.dart
+++ b/packages/langchain_huggingface/example/langchain_huggingface_example.dart
@@ -1,3 +1,33 @@
-void main() {
-  // TODO
+// ignore_for_file: avoid_print, unused_element
+
+import 'package:langchain_core/prompts.dart';
+
+import 'package:langchain_huggingface/langchain_huggingface.dart';
+
+void main() async {
+  // Run the examples:
+  await _example1();
+  await _example2();
+}
+
+/// The most basic building block of LangChain is calling an LLM on some input.
+Future<void> _example1() async {
+  final huggingFace = HuggingfaceInference.call(
+    model: 'gpt2',
+    apiKey: '....API_KEY...',
+  );
+  final result = await huggingFace('Who are you?');
+  print(result);
+}
+
+/// Streams the generated text as it is produced.
+Future<void> _example2() async {
+  final huggingFace = HuggingfaceInference.call(
+    model: 'gpt2',
+    apiKey: '....API_KEY...',
+  );
+
+  final stream = huggingFace.stream(PromptValue.string('Who are you?'));
+
+  stream.listen(print);
 }
diff --git a/packages/langchain_huggingface/lib/langchain_huggingface.dart b/packages/langchain_huggingface/lib/langchain_huggingface.dart
index 3f6c9ef0..6a0ab625 100644
--- a/packages/langchain_huggingface/lib/langchain_huggingface.dart
+++ b/packages/langchain_huggingface/lib/langchain_huggingface.dart
@@ -1,2 +1,4 @@
 /// Hugging Face module for LangChain.dart.
 library;
+
+export 'src/llm/llm.dart';
diff --git a/packages/langchain_huggingface/lib/src/llm/huggingface_inference.dart b/packages/langchain_huggingface/lib/src/llm/huggingface_inference.dart
new file mode 100644
index 00000000..87ba57a4
--- /dev/null
+++ b/packages/langchain_huggingface/lib/src/llm/huggingface_inference.dart
@@ -0,0 +1,95 @@
+import 'package:huggingface_client/huggingface_client.dart';
+import 'package:langchain_core/llms.dart';
+import 'package:langchain_core/prompts.dart';
+import 'package:meta/meta.dart';
+
+import 'mappers.dart';
+import 'types.dart';
+
+/// Wrapper around the Hugging Face Inference API for text generation.
+@immutable
+class HuggingfaceInference extends BaseLLM<HuggingFaceOptions> {
+  const HuggingfaceInference._({
+    required this.model,
+    required this.apiKey,
+    required this.apiClient,
+    super.defaultOptions = const HuggingFaceOptions(),
+  });
+
+  /// Creates a [HuggingfaceInference] for [model] authenticated with [apiKey].
+  factory HuggingfaceInference.call({
+    required String apiKey,
+    required String model,
+  }) {
+    final apiClient = InferenceApi(HuggingFaceClient.getInferenceClient(
+        apiKey, HuggingFaceClient.inferenceBasePath));
+    return HuggingfaceInference._(
+        model: model, apiKey: apiKey, apiClient: apiClient);
+  }
+
+  /// The client used to call the Hugging Face Inference API.
+  final InferenceApi apiClient;
+
+  /// The Hugging Face API key.
+  final String apiKey;
+
+  /// The id of the model to use (e.g. `gpt2`).
+  final String model;
+
+  @override
+  Future<LLMResult> invoke(
+    PromptValue input, {
+    HuggingFaceOptions? options,
+  }) async {
+    // Build the text-generation request from the prompt and options.
+    final parameters = ApiQueryNLPTextGeneration(
+        inputs: input.toString(),
+        temperature: options?.temperature ?? 1.0,
+        topK: options?.topK ?? 0,
+        topP: options?.topP ?? 0.0,
+        maxTime: options?.maxTime ?? -1.0,
+        returnFullText: options?.returnFullText ?? true,
+        repetitionPenalty: options?.repetitionPenalty ?? -1,
+        doSample: options?.doSample ?? true,
+        maxNewTokens: options?.maxNewTokens ?? -1,
+        options: InferenceOptions(
+            useCache: options?.useCache ?? true,
+            waitForModel: options?.waitForModel ?? false));
+    final result = await apiClient.queryNLPTextGeneration(
+        taskParameters: parameters, model: model);
+
+    return result![0]!.toLLMResult();
+  }
+
+  @override
+  Stream<LLMResult> stream(PromptValue input, {HuggingFaceOptions? options}) {
+    final query = ApiQueryNLPTextGeneration(
+        inputs: input.toString(),
+        temperature: options?.temperature ?? 1.0,
+        topK: options?.topK ?? 0,
+        topP: options?.topP ?? 0.0,
+        maxTime: options?.maxTime ?? -1.0,
+        returnFullText: options?.returnFullText ?? true,
+        repetitionPenalty: options?.repetitionPenalty ?? -1,
+        doSample: options?.doSample ?? true,
+        maxNewTokens: options?.maxNewTokens ?? -1,
+        options: InferenceOptions(
+            useCache: options?.useCache ?? true,
+            waitForModel: options?.waitForModel ?? false));
+    final stream = apiClient.textStreamGeneration(query: query, model: model);
+
+    return stream.map((response) => response.toLLMResult());
+  }
+
+  @override
+  String get modelType => 'llm';
+
+  @override
+  Future<List<int>> tokenize(
+    PromptValue promptValue, {
+    HuggingFaceOptions? options,
+  }) async {
+    // TODO: implement tokenize
+    throw UnimplementedError();
+  }
+}
diff --git a/packages/langchain_huggingface/lib/src/llm/llm.dart b/packages/langchain_huggingface/lib/src/llm/llm.dart
new file mode 100644
index 00000000..a5022ed2
--- /dev/null
+++ b/packages/langchain_huggingface/lib/src/llm/llm.dart
@@ -0,0 +1,2 @@
+export 'huggingface_inference.dart';
+export 'types.dart';
diff --git a/packages/langchain_huggingface/lib/src/llm/mappers.dart b/packages/langchain_huggingface/lib/src/llm/mappers.dart
new file mode 100644
index 00000000..71c9ecfd
--- /dev/null
+++ b/packages/langchain_huggingface/lib/src/llm/mappers.dart
@@ -0,0 +1,27 @@
+import 'package:huggingface_client/huggingface_client.dart';
+import 'package:langchain_core/language_models.dart';
+import 'package:langchain_core/llms.dart';
+
+/// Maps an [ApiResponseNLPTextGeneration] to an [LLMResult].
+extension HuggingFaceResponseMapper on ApiResponseNLPTextGeneration {
+  LLMResult toLLMResult() {
+    return LLMResult(
+        id: 'id',
+        output: generatedText,
+        finishReason: FinishReason.unspecified,
+        metadata: {},
+        usage: const LanguageModelUsage());
+  }
+}
+
+/// Maps a [TextGenerationStreamResponse] to an [LLMResult].
+extension HuggingFaceStreamResponseMapper on TextGenerationStreamResponse {
+  LLMResult toLLMResult() {
+    return LLMResult(
+        id: id.toString(),
+        output: text,
+        finishReason: FinishReason.unspecified,
+        metadata: {},
+        usage: const LanguageModelUsage());
+  }
+}
diff --git a/packages/langchain_huggingface/lib/src/llm/types.dart b/packages/langchain_huggingface/lib/src/llm/types.dart
new file mode 100644
index 00000000..c0b32a46
--- /dev/null
+++ b/packages/langchain_huggingface/lib/src/llm/types.dart
@@ -0,0 +1,108 @@
+import 'package:langchain_core/llms.dart';
+import 'package:meta/meta.dart';
+
+/// Generation options to pass to the Hugging Face Inference API.
+@immutable
+class HuggingFaceOptions extends LLMOptions {
+  const HuggingFaceOptions(
+      {this.topK,
+      this.topP,
+      super.model,
+      this.temperature,
+      this.repetitionPenalty,
+      this.maxNewTokens,
+      this.maxTime,
+      this.returnFullText,
+      this.numReturnSequences,
+      this.useCache,
+      this.waitForModel,
+      this.doSample});
+
+  /// (Default: true). Boolean. There is a cache layer on the Inference API to
+  /// speed up requests we have already seen. Most models can use those results
+  /// as is, as models are deterministic (meaning the results will be the same
+  /// anyway). However, if you use a non-deterministic model, you can set this
+  /// parameter to prevent the caching mechanism from being used, resulting in
+  /// a real new query.
+  final bool? useCache;
+
+  /// (Default: false). Boolean. If the model is not ready, wait for it instead
+  /// of receiving a 503 error. It limits the number of requests required to
+  /// get your inference done. It is advised to only set this flag to true
+  /// after receiving a 503 error, as it will limit hanging in your application
+  /// to known places.
+  final bool? waitForModel;
+
+  /// (Default: None). Integer to define the top tokens considered within the
+  /// sample operation to create new text.
+  final int? topK;
+
+  /// (Default: None). Float to define the tokens that are within the sample
+  /// operation of text generation. Tokens are added to the sample, from most
+  /// probable to least probable, until the sum of their probabilities is
+  /// greater than top_p.
+  final double? topP;
+
+  /// (Default: 1.0). Float (0.0-100.0). The temperature of the sampling
+  /// operation. 1 means regular sampling, 0 means always take the highest
+  /// score, 100.0 is getting closer to uniform probability.
+  final double? temperature;
+
+  /// (Default: None). Float (0.0-100.0). The more a token is used within
+  /// generation, the more it is penalized to not be picked in successive
+  /// generation passes.
+  final double? repetitionPenalty;
+
+  /// (Default: None). Int (0-250). The amount of new tokens to be generated.
+  /// This does not include the input length; it is an estimate of the size of
+  /// the generated text you want. Each new token slows down the request, so
+  /// look for a balance between response time and length of text generated.
+  final int? maxNewTokens;
+
+  /// (Default: None). Float (0-120.0). The maximum amount of time in seconds
+  /// that the query should take. The network can cause some overhead, so it is
+  /// a soft limit. Use it in combination with [maxNewTokens] for best results.
+  final double? maxTime;
+
+  /// (Default: true). Bool. If set to false, the returned results will not
+  /// contain the original query, making it easier for prompting.
+  final bool? returnFullText;
+
+  /// (Default: 1). Integer. The number of propositions you want returned.
+  final int? numReturnSequences;
+
+  /// (Optional: true). Bool. Whether or not to use sampling; use greedy
+  /// decoding otherwise.
+  final bool? doSample;
+
+  @override
+  HuggingFaceOptions copyWith(
+      {final String? model,
+      final int? concurrencyLimit,
+      final int? topK,
+      final double? topP,
+      final double? temperature,
+      final double? repetitionPenalty,
+      final int? maxNewTokens,
+      final double? maxTime,
+      final bool? returnFullText,
+      final int? numReturnSequences,
+      final bool? useCache,
+      final bool? waitForModel,
+      final bool? doSample}) {
+    return HuggingFaceOptions(
+      model: model ?? this.model,
+      repetitionPenalty: repetitionPenalty ?? this.repetitionPenalty,
+      returnFullText: returnFullText ?? this.returnFullText,
+      numReturnSequences: numReturnSequences ?? this.numReturnSequences,
+      doSample: doSample ?? this.doSample,
+      topK: topK ?? this.topK,
+      temperature: temperature ?? this.temperature,
+      topP: topP ?? this.topP,
+      maxTime: maxTime ?? this.maxTime,
+      maxNewTokens: maxNewTokens ?? this.maxNewTokens,
+      useCache: useCache ?? this.useCache,
+      waitForModel: waitForModel ?? this.waitForModel,
+    );
+  }
+}
diff --git a/packages/langchain_huggingface/pubspec.yaml b/packages/langchain_huggingface/pubspec.yaml
index 2f29e62b..71447571 100644
--- a/packages/langchain_huggingface/pubspec.yaml
+++ b/packages/langchain_huggingface/pubspec.yaml
@@ -15,3 +15,7 @@ topics:
 
 environment:
   sdk: ">=3.4.0 <4.0.0"
+dependencies:
+  huggingface_client: ^1.6.0
+  langchain_core: ^0.3.6
+  meta: ^1.11.0
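
Usage sketch (not part of the patch above): a minimal example of passing the new per-call HuggingFaceOptions to invoke, assuming exactly the HuggingfaceInference.call factory, invoke signature, and option fields introduced in this diff; the model id and API key are placeholders.

import 'package:langchain_core/prompts.dart';
import 'package:langchain_huggingface/langchain_huggingface.dart';

Future<void> main() async {
  // Placeholder model id and API key.
  final llm = HuggingfaceInference.call(
    model: 'gpt2',
    apiKey: 'HF_API_KEY',
  );

  // Override the default generation parameters for this call only.
  final result = await llm.invoke(
    PromptValue.string('Who are you?'),
    options: const HuggingFaceOptions(
      temperature: 0.7,
      maxNewTokens: 64,
      waitForModel: true,
    ),
  );
  print(result.output);
}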