lordgamez commented on code in PR #1903: URL: https://github.com/apache/nifi-minifi-cpp/pull/1903#discussion_r2039161304
########## METRICS.md: ########## @@ -288,3 +289,18 @@ Processor level metric that reports metrics for the GetFile processor if defined | metric_class | Class name to filter for this metric, set to GetFileMetrics | | processor_name | Name of the processor | | processor_uuid | UUID of the processor | + +### RunLlamaCppInferenceMetrics + +Processor level metric that reports metrics for the RunLlamaCppInference processor if defined in the flow configuration. + +| Metric name | Labels | Description | +|-----------------------|----------------------------------------------|----------------------------------------------------------------------------| +| tokens_in | metric_class, processor_name, processor_uuid | Number of tokens parsed from the input prompts in the processor's lifetime | +| tokens_out | metric_class, processor_name, processor_uuid | Number of tokens generated in the completion in the processor's lifetime | + +| Label | Description | +|----------------|----------------------------------------------------------------| +| metric_class | Class name to filter for this metric, set to GetFileMetrics | Review Comment: Good catch, updated in https://github.com/apache/nifi-minifi-cpp/pull/1903/commits/e93db2c01d6e2e159b287add948f5386319b21b6 ########## extensions/llamacpp/processors/DefaultLlamaContext.cpp: ########## @@ -0,0 +1,161 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "DefaultLlamaContext.h" +#include "Exception.h" +#include "fmt/format.h" +#include "utils/ConfigurationUtils.h" + +namespace org::apache::nifi::minifi::extensions::llamacpp::processors { + +namespace { +std::vector<llama_token> tokenizeInput(const llama_vocab* vocab, const std::string& input) { + int32_t number_of_tokens = gsl::narrow<int32_t>(input.length()) + 2; + std::vector<llama_token> tokenized_input(number_of_tokens); + number_of_tokens = llama_tokenize(vocab, input.data(), gsl::narrow<int32_t>(input.length()), tokenized_input.data(), gsl::narrow<int32_t>(tokenized_input.size()), true, true); + if (number_of_tokens < 0) { + tokenized_input.resize(-number_of_tokens); + [[maybe_unused]] int32_t check = llama_tokenize(vocab, input.data(), gsl::narrow<int32_t>(input.length()), tokenized_input.data(), gsl::narrow<int32_t>(tokenized_input.size()), true, true); + gsl_Assert(check == -number_of_tokens); + } else { + tokenized_input.resize(number_of_tokens); + } + return tokenized_input; +} +} // namespace + + +DefaultLlamaContext::DefaultLlamaContext(const std::filesystem::path& model_path, const LlamaSamplerParams& llama_sampler_params, const LlamaContextParams& llama_ctx_params) { + llama_backend_init(); + + llama_model_ = llama_model_load_from_file(model_path.string().c_str(), llama_model_default_params()); // NOLINT(cppcoreguidelines-prefer-member-initializer) + if (!llama_model_) { + throw Exception(ExceptionType::PROCESS_SCHEDULE_EXCEPTION, fmt::format("Failed to load model from '{}'", model_path.string())); + } + + llama_context_params ctx_params = llama_context_default_params(); + ctx_params.n_ctx = llama_ctx_params.n_ctx; + ctx_params.n_batch = llama_ctx_params.n_batch; + ctx_params.n_ubatch = llama_ctx_params.n_ubatch; + ctx_params.n_seq_max = llama_ctx_params.n_seq_max; + ctx_params.n_threads = llama_ctx_params.n_threads; + ctx_params.n_threads_batch = llama_ctx_params.n_threads_batch; + ctx_params.flash_attn = false; + llama_ctx_ = llama_init_from_model(llama_model_, ctx_params); + + auto sparams = llama_sampler_chain_default_params(); + llama_sampler_ = llama_sampler_chain_init(sparams); + + if (llama_sampler_params.min_p) { + llama_sampler_chain_add(llama_sampler_, llama_sampler_init_min_p(*llama_sampler_params.min_p, llama_sampler_params.min_keep)); + } + if (llama_sampler_params.top_k) { + llama_sampler_chain_add(llama_sampler_, llama_sampler_init_top_k(*llama_sampler_params.top_k)); + } + if (llama_sampler_params.top_p) { + llama_sampler_chain_add(llama_sampler_, llama_sampler_init_top_p(*llama_sampler_params.top_p, llama_sampler_params.min_keep)); + } + if (llama_sampler_params.temperature) { + llama_sampler_chain_add(llama_sampler_, llama_sampler_init_temp(*llama_sampler_params.temperature)); + } + llama_sampler_chain_add(llama_sampler_, llama_sampler_init_dist(LLAMA_DEFAULT_SEED)); +} + +DefaultLlamaContext::~DefaultLlamaContext() { + llama_sampler_free(llama_sampler_); + llama_sampler_ = nullptr; + llama_free(llama_ctx_); + llama_ctx_ = nullptr; + llama_model_free(llama_model_); + llama_model_ = nullptr; + llama_backend_free(); +} + +std::optional<std::string> DefaultLlamaContext::applyTemplate(const std::vector<LlamaChatMessage>& messages) { + std::vector<llama_chat_message> llama_messages; + llama_messages.reserve(messages.size()); + for (auto& msg : messages) { + llama_messages.push_back(llama_chat_message{.role = msg.role.c_str(), .content = msg.content.c_str()}); + } + std::string text; + text.resize(utils::configuration::DEFAULT_BUFFER_SIZE); + 
const char * chat_template = llama_model_chat_template(llama_model_, nullptr); + int32_t res_size = llama_chat_apply_template(chat_template, llama_messages.data(), llama_messages.size(), true, text.data(), gsl::narrow<int32_t>(text.size())); + if (res_size < 0) { + return std::nullopt; + } + if (res_size > gsl::narrow<int32_t>(text.size())) { + text.resize(res_size); + res_size = llama_chat_apply_template(chat_template, llama_messages.data(), llama_messages.size(), true, text.data(), gsl::narrow<int32_t>(text.size())); + if (res_size < 0) { + return std::nullopt; + } + } + text.resize(res_size); + + return text; +} + +nonstd::expected<GenerationResult, std::string> DefaultLlamaContext::generate(const std::string& input, std::function<void(std::string_view/*token*/)> token_handler) { + GenerationResult result{}; + auto start_time = std::chrono::steady_clock::now(); + const llama_vocab * vocab = llama_model_get_vocab(llama_model_); + std::vector<llama_token> tokenized_input = tokenizeInput(vocab, input); + result.num_tokens_in = gsl::narrow<uint64_t>(tokenized_input.size()); + + llama_batch batch = llama_batch_get_one(tokenized_input.data(), gsl::narrow<int32_t>(tokenized_input.size())); + llama_token new_token_id = 0; + bool first_token_generated = false; + while (true) { + int32_t res = llama_decode(llama_ctx_, batch); + if (res == 1) { + return nonstd::make_unexpected("Could not find a KV slot for the batch (try reducing the size of the batch or increase the context)"); + } else if (res < 0) { + return nonstd::make_unexpected("Error occurred while executing llama decode"); + } + + new_token_id = llama_sampler_sample(llama_sampler_, llama_ctx_, -1); + if (!first_token_generated) { + result.time_to_first_token = std::chrono::duration_cast<std::chrono::milliseconds>(std::chrono::steady_clock::now() - start_time); + first_token_generated = true; + } + + if (llama_vocab_is_eog(vocab, new_token_id)) { + break; + } + + ++result.num_tokens_out; + llama_sampler_accept(llama_sampler_, new_token_id); + + std::array<char, 128> buf{}; + int32_t len = llama_token_to_piece(vocab, new_token_id, buf.data(), gsl::narrow<int32_t>(buf.size()), 0, true); + if (len < 0) { + return nonstd::make_unexpected("Failed to convert token to text"); + } + gsl_Assert(len < 128); + + std::string_view token_str{buf.data(), gsl::narrow<std::string_view::size_type>(len)}; + batch = llama_batch_get_one(&new_token_id, 1); + token_handler(token_str); + } + + result.tokens_per_second = + gsl::narrow<double>(result.num_tokens_out) / gsl::narrow<double>(std::chrono::duration_cast<std::chrono::milliseconds>(std::chrono::steady_clock::now() - start_time).count()) / 1000.0; Review Comment: Good catch, fixed in https://github.com/apache/nifi-minifi-cpp/pull/1903/commits/e93db2c01d6e2e159b287add948f5386319b21b6 ########## extensions/llamacpp/processors/RunLlamaCppInference.cpp: ########## @@ -0,0 +1,164 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "RunLlamaCppInference.h" +#include "core/ProcessContext.h" +#include "core/ProcessSession.h" +#include "core/Resource.h" +#include "Exception.h" + +#include "rapidjson/document.h" +#include "rapidjson/error/en.h" +#include "LlamaContext.h" +#include "utils/ProcessorConfigUtils.h" + +namespace org::apache::nifi::minifi::extensions::llamacpp::processors { + +void RunLlamaCppInference::initialize() { + setSupportedProperties(Properties); + setSupportedRelationships(Relationships); +} + +void RunLlamaCppInference::onSchedule(core::ProcessContext& context, core::ProcessSessionFactory&) { + model_path_.clear(); + model_path_ = utils::parseProperty(context, ModelPath); + system_prompt_ = context.getProperty(SystemPrompt).value_or(""); + + LlamaSamplerParams llama_sampler_params; + llama_sampler_params.temperature = utils::parseOptionalFloatProperty(context, Temperature); + if (auto top_k = utils::parseOptionalI64Property(context, TopK)) { + llama_sampler_params.top_k = gsl::narrow<int32_t>(*top_k); + } + llama_sampler_params.top_p = utils::parseOptionalFloatProperty(context, TopP); + llama_sampler_params.min_p = utils::parseOptionalFloatProperty(context, MinP); + llama_sampler_params.min_keep = utils::parseU64Property(context, MinKeep); + + LlamaContextParams llama_ctx_params; + llama_ctx_params.n_ctx = gsl::narrow<uint32_t>(utils::parseU64Property(context, TextContextSize)); + llama_ctx_params.n_batch = gsl::narrow<uint32_t>(utils::parseU64Property(context, LogicalMaximumBatchSize)); + llama_ctx_params.n_ubatch = gsl::narrow<uint32_t>(utils::parseU64Property(context, PhysicalMaximumBatchSize)); + llama_ctx_params.n_seq_max = gsl::narrow<uint32_t>(utils::parseU64Property(context, MaxNumberOfSequences)); + llama_ctx_params.n_threads = gsl::narrow<int32_t>(utils::parseI64Property(context, ThreadsForGeneration)); + llama_ctx_params.n_threads_batch = gsl::narrow<int32_t>(utils::parseI64Property(context, ThreadsForBatchProcessing)); + + llama_ctx_ = LlamaContext::create(model_path_, llama_sampler_params, llama_ctx_params); +} + +void RunLlamaCppInference::increaseTokensIn(uint64_t token_count) { + auto* const llamacpp_metrics = dynamic_cast<RunLlamaCppInferenceMetrics*>(metrics_.get()); + gsl_Assert(llamacpp_metrics); + std::lock_guard<std::mutex> lock(llamacpp_metrics->tokens_in_mutex_); + if (llamacpp_metrics->tokens_in > std::numeric_limits<uint64_t>::max() - token_count) { + logger_->log_warn("Tokens in count overflow detected, resetting to 0"); + llamacpp_metrics->tokens_in = token_count; + return; + } + + llamacpp_metrics->tokens_in += token_count; +} + +void RunLlamaCppInference::increaseTokensOut(uint64_t token_count) { + auto* const llamacpp_metrics = dynamic_cast<RunLlamaCppInferenceMetrics*>(metrics_.get()); + gsl_Assert(llamacpp_metrics); + std::lock_guard<std::mutex> lock(llamacpp_metrics->tokens_out_mutex_); + if (llamacpp_metrics->tokens_out > std::numeric_limits<uint64_t>::max() - token_count) { + logger_->log_warn("Tokens out count overflow detected, resetting to 0"); + llamacpp_metrics->tokens_out = token_count; + return; + } + + 
llamacpp_metrics->tokens_out += token_count; +} + +void RunLlamaCppInference::onTrigger(core::ProcessContext& context, core::ProcessSession& session) { + auto flow_file = session.get(); + if (!flow_file) { + context.yield(); + return; + } + + auto prompt = context.getProperty(Prompt, flow_file.get()).value_or(""); + + auto read_result = session.readBuffer(flow_file); + std::string input_data_and_prompt; + if (!read_result.buffer.empty()) { + input_data_and_prompt.append("Input data (or flow file content):\n"); + input_data_and_prompt.append({reinterpret_cast<const char*>(read_result.buffer.data()), read_result.buffer.size()}); + input_data_and_prompt.append("\n\n"); + } + input_data_and_prompt.append(prompt); + + if (input_data_and_prompt.empty()) { + logger_->log_error("Input data and prompt are empty"); + session.transfer(flow_file, Failure); + return; + } + + auto input = [&] { + std::vector<LlamaChatMessage> messages; + if (!system_prompt_.empty()) { + messages.push_back({.role = "system", .content = system_prompt_}); + } + messages.push_back({.role = "user", .content = input_data_and_prompt}); + + return llama_ctx_->applyTemplate(messages); + }(); + + if (!input) { + logger_->log_error("Inference failed with while applying template"); + session.transfer(flow_file, Failure); + return; + } + + logger_->log_debug("AI model input: {}", *input); + + auto start_time = std::chrono::steady_clock::now(); + + std::string text; + auto generation_result = llama_ctx_->generate(*input, [&] (std::string_view token) { + text += token; + }); + + auto elapsed_time = std::chrono::duration_cast<std::chrono::milliseconds>(std::chrono::steady_clock::now() - start_time).count(); + + if (!generation_result) { + logger_->log_error("Inference failed with generation error: '{}'", generation_result.error()); + session.transfer(flow_file, Failure); + return; + } + + increaseTokensIn(generation_result->num_tokens_in); + increaseTokensOut(generation_result->num_tokens_out); + + logger_->log_debug("Number of tokens generated: {}", generation_result->num_tokens_out); + logger_->log_debug("AI model inference time: {} ms", elapsed_time); + logger_->log_debug("AI model output: {}", text); + + flow_file->setAttribute(LlamaCppTimeToFirstToken.name, std::to_string(generation_result->time_to_first_token.count()) + " ms"); + flow_file->setAttribute(LlamaCppTokensPerSecond.name, std::to_string(generation_result->tokens_per_second)); Review Comment: I think at least 2 decimal points should be used and fixed to be consistent, updated in https://github.com/apache/nifi-minifi-cpp/pull/1903/commits/e93db2c01d6e2e159b287add948f5386319b21b6 ########## extensions/llamacpp/tests/RunLlamaCppInferenceTests.cpp: ########## @@ -0,0 +1,351 @@ +/** + * + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include "unit/TestBase.h" +#include "unit/Catch.h" +#include "RunLlamaCppInference.h" +#include "unit/SingleProcessorTestController.h" +#include "core/FlowFile.h" + +namespace org::apache::nifi::minifi::extensions::llamacpp::test { + +class MockLlamaContext : public processors::LlamaContext { + public: + std::optional<std::string> applyTemplate(const std::vector<processors::LlamaChatMessage>& messages) override { + if (fail_apply_template_) { + return std::nullopt; + } + messages_ = messages; + return "Test input"; + } + + nonstd::expected<processors::GenerationResult, std::string> generate(const std::string& input, std::function<void(std::string_view/*token*/)> token_handler) override { + if (fail_generation_) { + return nonstd::make_unexpected("Generation failed"); + } + processors::GenerationResult result; + input_ = input; + token_handler("Test "); + token_handler("generated"); + token_handler(" content"); + result.time_to_first_token = std::chrono::milliseconds(100); + result.num_tokens_in = 10; + result.num_tokens_out = 3; + result.tokens_per_second = 2.0; + return result; + } + + [[nodiscard]] const std::vector<processors::LlamaChatMessage>& getMessages() const { + return messages_; + } + + [[nodiscard]] const std::string& getInput() const { + return input_; + } + + void setGenerationFailure() { + fail_generation_ = true; + } + + void setApplyTemplateFailure() { + fail_apply_template_ = true; + } + + private: + bool fail_generation_{false}; + bool fail_apply_template_{false}; + std::vector<processors::LlamaChatMessage> messages_; + std::string input_; +}; + +TEST_CASE("Prompt is generated correctly with default parameters") { + auto mock_llama_context = std::make_unique<MockLlamaContext>(); + auto mock_llama_context_ptr = mock_llama_context.get(); + std::filesystem::path test_model_path; + processors::LlamaSamplerParams test_sampler_params; + processors::LlamaContextParams test_context_params; + processors::LlamaContext::testSetProvider( + [&](const std::filesystem::path& model_path, const processors::LlamaSamplerParams& sampler_params, const processors::LlamaContextParams& context_params) { + test_model_path = model_path; + test_sampler_params = sampler_params; + test_context_params = context_params; + return std::move(mock_llama_context); + }); + minifi::test::SingleProcessorTestController controller(std::make_unique<processors::RunLlamaCppInference>("RunLlamaCppInference")); + LogTestController::getInstance().setTrace<processors::RunLlamaCppInference>(); + controller.getProcessor()->setProperty(processors::RunLlamaCppInference::ModelPath.name, "Dummy model"); + controller.getProcessor()->setProperty(processors::RunLlamaCppInference::Prompt.name, "Question: What is the answer to life, the universe and everything?"); + + auto results = controller.trigger(minifi::test::InputFlowFileData{.content = "42", .attributes = {}}); + CHECK(test_model_path == "Dummy model"); + CHECK(test_sampler_params.temperature == 0.8F); + CHECK(test_sampler_params.top_k == 40); + CHECK(test_sampler_params.top_p == 0.9F); + CHECK(test_sampler_params.min_p == std::nullopt); + CHECK(test_sampler_params.min_keep == 0); + CHECK(test_context_params.n_ctx == 4096); + CHECK(test_context_params.n_batch == 2048); + CHECK(test_context_params.n_ubatch == 512); + CHECK(test_context_params.n_seq_max == 1); + CHECK(test_context_params.n_threads == 4); + CHECK(test_context_params.n_threads_batch == 4); + + REQUIRE(results.at(processors::RunLlamaCppInference::Success).size() == 1); + auto& output_flow_file = 
results.at(processors::RunLlamaCppInference::Success)[0]; + CHECK(*output_flow_file->getAttribute(processors::RunLlamaCppInference::LlamaCppTimeToFirstToken.name) == "100 ms"); + CHECK(*output_flow_file->getAttribute(processors::RunLlamaCppInference::LlamaCppTokensPerSecond.name) == "2.000000"); + CHECK(controller.plan->getContent(output_flow_file) == "Test generated content"); + CHECK(mock_llama_context_ptr->getInput() == "Test input"); + REQUIRE(mock_llama_context_ptr->getMessages().size() == 2); + CHECK(mock_llama_context_ptr->getMessages()[0].role == "system"); + CHECK(mock_llama_context_ptr->getMessages()[0].content == "You are a helpful assistant. You are given a question with some possible input data otherwise called flow file content. " + "You are expected to generate a response based on the question and the input data."); + CHECK(mock_llama_context_ptr->getMessages()[1].role == "user"); + CHECK(mock_llama_context_ptr->getMessages()[1].content == "Input data (or flow file content):\n42\n\nQuestion: What is the answer to life, the universe and everything?"); +} + +TEST_CASE("Prompt is generated correctly with custom parameters") { + auto mock_llama_context = std::make_unique<MockLlamaContext>(); + auto mock_llama_context_ptr = mock_llama_context.get(); + std::filesystem::path test_model_path; + processors::LlamaSamplerParams test_sampler_params; + processors::LlamaContextParams test_context_params; + processors::LlamaContext::testSetProvider( + [&](const std::filesystem::path& model_path, const processors::LlamaSamplerParams& sampler_params, const processors::LlamaContextParams& context_params) { + test_model_path = model_path; + test_sampler_params = sampler_params; + test_context_params = context_params; + return std::move(mock_llama_context); + }); + minifi::test::SingleProcessorTestController controller(std::make_unique<processors::RunLlamaCppInference>("RunLlamaCppInference")); + LogTestController::getInstance().setTrace<processors::RunLlamaCppInference>(); + controller.getProcessor()->setProperty(processors::RunLlamaCppInference::ModelPath.name, "/path/to/model"); + controller.getProcessor()->setProperty(processors::RunLlamaCppInference::Prompt.name, "Question: What is the answer to life, the universe and everything?"); + controller.getProcessor()->setProperty(processors::RunLlamaCppInference::Temperature.name, "0.4"); + controller.getProcessor()->setProperty(processors::RunLlamaCppInference::TopK.name, "20"); + controller.getProcessor()->setProperty(processors::RunLlamaCppInference::TopP.name, ""); + controller.getProcessor()->setProperty(processors::RunLlamaCppInference::MinP.name, "0.1"); + controller.getProcessor()->setProperty(processors::RunLlamaCppInference::MinKeep.name, "1"); + controller.getProcessor()->setProperty(processors::RunLlamaCppInference::TextContextSize.name, "4096"); + controller.getProcessor()->setProperty(processors::RunLlamaCppInference::LogicalMaximumBatchSize.name, "1024"); + controller.getProcessor()->setProperty(processors::RunLlamaCppInference::PhysicalMaximumBatchSize.name, "796"); + controller.getProcessor()->setProperty(processors::RunLlamaCppInference::MaxNumberOfSequences.name, "2"); + controller.getProcessor()->setProperty(processors::RunLlamaCppInference::ThreadsForGeneration.name, "12"); + controller.getProcessor()->setProperty(processors::RunLlamaCppInference::ThreadsForBatchProcessing.name, "8"); + controller.getProcessor()->setProperty(processors::RunLlamaCppInference::SystemPrompt.name, "Whatever"); + + auto results = 
controller.trigger(minifi::test::InputFlowFileData{.content = "42", .attributes = {}}); + CHECK(test_model_path == "/path/to/model"); + CHECK(test_sampler_params.temperature == 0.4F); + CHECK(test_sampler_params.top_k == 20); + CHECK(test_sampler_params.top_p == std::nullopt); + CHECK(test_sampler_params.min_p == 0.1F); + CHECK(test_sampler_params.min_keep == 1); + CHECK(test_context_params.n_ctx == 4096); + CHECK(test_context_params.n_batch == 1024); + CHECK(test_context_params.n_ubatch == 796); + CHECK(test_context_params.n_seq_max == 2); + CHECK(test_context_params.n_threads == 12); + CHECK(test_context_params.n_threads_batch == 8); + + REQUIRE(results.at(processors::RunLlamaCppInference::Success).size() == 1); + auto& output_flow_file = results.at(processors::RunLlamaCppInference::Success)[0]; + CHECK(controller.plan->getContent(output_flow_file) == "Test generated content"); + CHECK(mock_llama_context_ptr->getInput() == "Test input"); + REQUIRE(mock_llama_context_ptr->getMessages().size() == 2); + CHECK(mock_llama_context_ptr->getMessages()[0].role == "system"); + CHECK(mock_llama_context_ptr->getMessages()[0].content == "Whatever"); + CHECK(mock_llama_context_ptr->getMessages()[1].role == "user"); + CHECK(mock_llama_context_ptr->getMessages()[1].content == "Input data (or flow file content):\n42\n\nQuestion: What is the answer to life, the universe and everything?"); +} + +TEST_CASE("Empty flow file does not include input data in prompt") { + auto mock_llama_context = std::make_unique<MockLlamaContext>(); + auto mock_llama_context_ptr = mock_llama_context.get(); + processors::LlamaContext::testSetProvider( + [&](const std::filesystem::path&, const processors::LlamaSamplerParams&, const processors::LlamaContextParams&) { + return std::move(mock_llama_context); + }); + minifi::test::SingleProcessorTestController controller(std::make_unique<processors::RunLlamaCppInference>("RunLlamaCppInference")); + LogTestController::getInstance().setTrace<processors::RunLlamaCppInference>(); + controller.getProcessor()->setProperty(processors::RunLlamaCppInference::ModelPath.name, "Dummy model"); + controller.getProcessor()->setProperty(processors::RunLlamaCppInference::Prompt.name, "Question: What is the answer to life, the universe and everything?"); + + auto results = controller.trigger(minifi::test::InputFlowFileData{.content = "", .attributes = {}}); + + REQUIRE(results.at(processors::RunLlamaCppInference::Success).size() == 1); + auto& output_flow_file = results.at(processors::RunLlamaCppInference::Success)[0]; + CHECK(controller.plan->getContent(output_flow_file) == "Test generated content"); + CHECK(mock_llama_context_ptr->getInput() == "Test input"); + REQUIRE(mock_llama_context_ptr->getMessages().size() == 2); + CHECK(mock_llama_context_ptr->getMessages()[0].role == "system"); + CHECK(mock_llama_context_ptr->getMessages()[0].content == "You are a helpful assistant. You are given a question with some possible input data otherwise called flow file content. 
" + "You are expected to generate a response based on the question and the input data."); + CHECK(mock_llama_context_ptr->getMessages()[1].role == "user"); + CHECK(mock_llama_context_ptr->getMessages()[1].content == "Question: What is the answer to life, the universe and everything?"); +} + +TEST_CASE("Invalid values for optional double type properties throw exception") { + processors::LlamaContext::testSetProvider( + [&](const std::filesystem::path&, const processors::LlamaSamplerParams&, const processors::LlamaContextParams&) { + return std::make_unique<MockLlamaContext>(); + }); + minifi::test::SingleProcessorTestController controller(std::make_unique<processors::RunLlamaCppInference>("RunLlamaCppInference")); + LogTestController::getInstance().setTrace<processors::RunLlamaCppInference>(); + controller.getProcessor()->setProperty(processors::RunLlamaCppInference::ModelPath.name, "Dummy model"); + controller.getProcessor()->setProperty(processors::RunLlamaCppInference::Prompt.name, "Question: What is the answer to life, the universe and everything?"); + + std::string property_name; + SECTION("Invalid value for Temperature property") { + controller.getProcessor()->setProperty(processors::RunLlamaCppInference::Temperature.name, "invalid_value"); + property_name = processors::RunLlamaCppInference::Temperature.name; + } + SECTION("Invalid value for Top P property") { + controller.getProcessor()->setProperty(processors::RunLlamaCppInference::TopP.name, "invalid_value"); + property_name = processors::RunLlamaCppInference::TopP.name; + } + SECTION("Invalid value for Min P property") { + controller.getProcessor()->setProperty(processors::RunLlamaCppInference::MinP.name, "invalid_value"); + property_name = processors::RunLlamaCppInference::MinP.name; + } + + REQUIRE_THROWS_WITH(controller.trigger(minifi::test::InputFlowFileData{.content = "42", .attributes = {}}), + fmt::format("Expected parsable float from RunLlamaCppInference::{}: parsing error: GeneralParsingError (0)", property_name)); +} + +TEST_CASE("Top K property empty and invalid values are handled properly") { + std::optional<int32_t> test_top_k = 0; + processors::LlamaContext::testSetProvider( + [&](const std::filesystem::path&, const processors::LlamaSamplerParams& sampler_params, const processors::LlamaContextParams&) { + test_top_k = sampler_params.top_k; + return std::make_unique<MockLlamaContext>(); + }); + minifi::test::SingleProcessorTestController controller(std::make_unique<processors::RunLlamaCppInference>("RunLlamaCppInference")); + LogTestController::getInstance().setTrace<processors::RunLlamaCppInference>(); + controller.getProcessor()->setProperty(processors::RunLlamaCppInference::ModelPath.name, "Dummy model"); + controller.getProcessor()->setProperty(processors::RunLlamaCppInference::Prompt.name, "Question: What is the answer to life, the universe and everything?"); + SECTION("Empty value for Top K property") { + controller.getProcessor()->setProperty(processors::RunLlamaCppInference::TopK.name, ""); + auto results = controller.trigger(minifi::test::InputFlowFileData{.content = "42", .attributes = {}}); + REQUIRE(test_top_k == std::nullopt); + } + SECTION("Invalid value for Top K property") { + controller.getProcessor()->setProperty(processors::RunLlamaCppInference::TopK.name, "invalid_value"); + REQUIRE_THROWS_WITH(controller.trigger(minifi::test::InputFlowFileData{.content = "42", .attributes = {}}), + "Expected parsable int64_t from RunLlamaCppInference::Top K: parsing error: GeneralParsingError (0)"); + } +} + 
+TEST_CASE("Error handling during generation and applying template") { + auto mock_llama_context = std::make_unique<MockLlamaContext>(); + + SECTION("Generation fails with error") { + mock_llama_context->setGenerationFailure(); + } + + SECTION("Applying template fails with error") { + mock_llama_context->setApplyTemplateFailure(); + } + + processors::LlamaContext::testSetProvider( + [&](const std::filesystem::path&, const processors::LlamaSamplerParams&, const processors::LlamaContextParams&) { + return std::move(mock_llama_context); + }); + minifi::test::SingleProcessorTestController controller(std::make_unique<processors::RunLlamaCppInference>("RunLlamaCppInference")); + LogTestController::getInstance().setTrace<processors::RunLlamaCppInference>(); + controller.getProcessor()->setProperty(processors::RunLlamaCppInference::ModelPath.name, "/path/to/model"); + controller.getProcessor()->setProperty(processors::RunLlamaCppInference::Prompt.name, "Question: What is the answer to life, the universe and everything?"); + + auto results = controller.trigger(minifi::test::InputFlowFileData{.content = "42", .attributes = {}}); + + REQUIRE(results.at(processors::RunLlamaCppInference::Success).empty()); + REQUIRE(results.at(processors::RunLlamaCppInference::Failure).size() == 1); + auto& output_flow_file = results.at(processors::RunLlamaCppInference::Failure)[0]; + CHECK(controller.plan->getContent(output_flow_file) == "42"); +} + +TEST_CASE("Route flow file to failure when prompt and input data is empty") { + processors::LlamaContext::testSetProvider( + [&](const std::filesystem::path&, const processors::LlamaSamplerParams&, const processors::LlamaContextParams&) { + return std::make_unique<MockLlamaContext>(); + }); + minifi::test::SingleProcessorTestController controller(std::make_unique<processors::RunLlamaCppInference>("RunLlamaCppInference")); + LogTestController::getInstance().setTrace<processors::RunLlamaCppInference>(); + controller.getProcessor()->setProperty(processors::RunLlamaCppInference::ModelPath.name, "/path/to/model"); + controller.getProcessor()->setProperty(processors::RunLlamaCppInference::Prompt.name, ""); + + auto results = controller.trigger(minifi::test::InputFlowFileData{.content = "", .attributes = {}}); + + REQUIRE(results.at(processors::RunLlamaCppInference::Success).empty()); + REQUIRE(results.at(processors::RunLlamaCppInference::Failure).size() == 1); + auto& output_flow_file = results.at(processors::RunLlamaCppInference::Failure)[0]; + CHECK(controller.plan->getContent(output_flow_file).empty()); +} + +TEST_CASE("System prompt is optional") { + auto mock_llama_context = std::make_unique<MockLlamaContext>(); + auto mock_llama_context_ptr = mock_llama_context.get(); + processors::LlamaContext::testSetProvider( + [&](const std::filesystem::path&, const processors::LlamaSamplerParams&, const processors::LlamaContextParams&) { + return std::move(mock_llama_context); + }); + minifi::test::SingleProcessorTestController controller(std::make_unique<processors::RunLlamaCppInference>("RunLlamaCppInference")); + LogTestController::getInstance().setTrace<processors::RunLlamaCppInference>(); + controller.getProcessor()->setProperty(processors::RunLlamaCppInference::ModelPath.name, "Dummy model"); + controller.getProcessor()->setProperty(processors::RunLlamaCppInference::Prompt.name, "Question: What is the answer to life, the universe and everything?"); + controller.getProcessor()->setProperty(processors::RunLlamaCppInference::SystemPrompt.name, ""); + + auto results = 
controller.trigger(minifi::test::InputFlowFileData{.content = "42", .attributes = {}}); + + REQUIRE(results.at(processors::RunLlamaCppInference::Success).size() == 1); + auto& output_flow_file = results.at(processors::RunLlamaCppInference::Success)[0]; + CHECK(controller.plan->getContent(output_flow_file) == "Test generated content"); + CHECK(mock_llama_context_ptr->getInput() == "Test input"); + REQUIRE(mock_llama_context_ptr->getMessages().size() == 1); + CHECK(mock_llama_context_ptr->getMessages()[0].role == "user"); + CHECK(mock_llama_context_ptr->getMessages()[0].content == "Input data (or flow file content):\n42\n\nQuestion: What is the answer to life, the universe and everything?"); +} + +TEST_CASE("Test output metrics") { + processors::LlamaContext::testSetProvider( + [&](const std::filesystem::path&, const processors::LlamaSamplerParams&, const processors::LlamaContextParams&) { + return std::make_unique<MockLlamaContext>(); + }); + auto processor = std::make_unique<processors::RunLlamaCppInference>("RunLlamaCppInference"); + auto processor_metrics = processor->getMetrics(); + minifi::test::SingleProcessorTestController controller(std::move(processor)); + LogTestController::getInstance().setTrace<processors::RunLlamaCppInference>(); + controller.getProcessor()->setProperty(processors::RunLlamaCppInference::ModelPath.name, "Dummy model"); + controller.getProcessor()->setProperty(processors::RunLlamaCppInference::Prompt.name, "Question: What is the answer to life, the universe and everything?"); + + controller.trigger(minifi::test::InputFlowFileData{.content = "42", .attributes = {}}); + auto results = controller.trigger(minifi::test::InputFlowFileData{.content = "42", .attributes = {}}); + + REQUIRE(results.at(processors::RunLlamaCppInference::Success).size() == 1); + auto prometheus_metrics = processor_metrics->calculateMetrics(); + CHECK(prometheus_metrics[prometheus_metrics.size() - 2].name == "tokens_in"); Review Comment: Updated in https://github.com/apache/nifi-minifi-cpp/pull/1903/commits/e93db2c01d6e2e159b287add948f5386319b21b6 ########## extensions/llamacpp/tests/RunLlamaCppInferenceTests.cpp: ########## @@ -0,0 +1,351 @@ +/** + * + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include "unit/TestBase.h" +#include "unit/Catch.h" +#include "RunLlamaCppInference.h" +#include "unit/SingleProcessorTestController.h" +#include "core/FlowFile.h" + +namespace org::apache::nifi::minifi::extensions::llamacpp::test { + +class MockLlamaContext : public processors::LlamaContext { + public: + std::optional<std::string> applyTemplate(const std::vector<processors::LlamaChatMessage>& messages) override { + if (fail_apply_template_) { + return std::nullopt; + } + messages_ = messages; + return "Test input"; + } + + nonstd::expected<processors::GenerationResult, std::string> generate(const std::string& input, std::function<void(std::string_view/*token*/)> token_handler) override { + if (fail_generation_) { + return nonstd::make_unexpected("Generation failed"); + } + processors::GenerationResult result; + input_ = input; + token_handler("Test "); + token_handler("generated"); + token_handler(" content"); + result.time_to_first_token = std::chrono::milliseconds(100); + result.num_tokens_in = 10; + result.num_tokens_out = 3; + result.tokens_per_second = 2.0; + return result; + } + + [[nodiscard]] const std::vector<processors::LlamaChatMessage>& getMessages() const { + return messages_; + } + + [[nodiscard]] const std::string& getInput() const { + return input_; + } + + void setGenerationFailure() { + fail_generation_ = true; + } + + void setApplyTemplateFailure() { + fail_apply_template_ = true; + } + + private: + bool fail_generation_{false}; + bool fail_apply_template_{false}; + std::vector<processors::LlamaChatMessage> messages_; + std::string input_; +}; + +TEST_CASE("Prompt is generated correctly with default parameters") { + auto mock_llama_context = std::make_unique<MockLlamaContext>(); + auto mock_llama_context_ptr = mock_llama_context.get(); + std::filesystem::path test_model_path; + processors::LlamaSamplerParams test_sampler_params; + processors::LlamaContextParams test_context_params; + processors::LlamaContext::testSetProvider( + [&](const std::filesystem::path& model_path, const processors::LlamaSamplerParams& sampler_params, const processors::LlamaContextParams& context_params) { + test_model_path = model_path; + test_sampler_params = sampler_params; + test_context_params = context_params; + return std::move(mock_llama_context); + }); + minifi::test::SingleProcessorTestController controller(std::make_unique<processors::RunLlamaCppInference>("RunLlamaCppInference")); + LogTestController::getInstance().setTrace<processors::RunLlamaCppInference>(); + controller.getProcessor()->setProperty(processors::RunLlamaCppInference::ModelPath.name, "Dummy model"); + controller.getProcessor()->setProperty(processors::RunLlamaCppInference::Prompt.name, "Question: What is the answer to life, the universe and everything?"); + + auto results = controller.trigger(minifi::test::InputFlowFileData{.content = "42", .attributes = {}}); + CHECK(test_model_path == "Dummy model"); + CHECK(test_sampler_params.temperature == 0.8F); + CHECK(test_sampler_params.top_k == 40); + CHECK(test_sampler_params.top_p == 0.9F); + CHECK(test_sampler_params.min_p == std::nullopt); + CHECK(test_sampler_params.min_keep == 0); + CHECK(test_context_params.n_ctx == 4096); + CHECK(test_context_params.n_batch == 2048); + CHECK(test_context_params.n_ubatch == 512); + CHECK(test_context_params.n_seq_max == 1); + CHECK(test_context_params.n_threads == 4); + CHECK(test_context_params.n_threads_batch == 4); + + REQUIRE(results.at(processors::RunLlamaCppInference::Success).size() == 1); + auto& output_flow_file = 
results.at(processors::RunLlamaCppInference::Success)[0]; + CHECK(*output_flow_file->getAttribute(processors::RunLlamaCppInference::LlamaCppTimeToFirstToken.name) == "100 ms"); + CHECK(*output_flow_file->getAttribute(processors::RunLlamaCppInference::LlamaCppTokensPerSecond.name) == "2.000000"); + CHECK(controller.plan->getContent(output_flow_file) == "Test generated content"); + CHECK(mock_llama_context_ptr->getInput() == "Test input"); + REQUIRE(mock_llama_context_ptr->getMessages().size() == 2); + CHECK(mock_llama_context_ptr->getMessages()[0].role == "system"); + CHECK(mock_llama_context_ptr->getMessages()[0].content == "You are a helpful assistant. You are given a question with some possible input data otherwise called flow file content. " + "You are expected to generate a response based on the question and the input data."); + CHECK(mock_llama_context_ptr->getMessages()[1].role == "user"); + CHECK(mock_llama_context_ptr->getMessages()[1].content == "Input data (or flow file content):\n42\n\nQuestion: What is the answer to life, the universe and everything?"); +} + +TEST_CASE("Prompt is generated correctly with custom parameters") { + auto mock_llama_context = std::make_unique<MockLlamaContext>(); + auto mock_llama_context_ptr = mock_llama_context.get(); + std::filesystem::path test_model_path; + processors::LlamaSamplerParams test_sampler_params; + processors::LlamaContextParams test_context_params; + processors::LlamaContext::testSetProvider( + [&](const std::filesystem::path& model_path, const processors::LlamaSamplerParams& sampler_params, const processors::LlamaContextParams& context_params) { + test_model_path = model_path; + test_sampler_params = sampler_params; + test_context_params = context_params; + return std::move(mock_llama_context); + }); + minifi::test::SingleProcessorTestController controller(std::make_unique<processors::RunLlamaCppInference>("RunLlamaCppInference")); + LogTestController::getInstance().setTrace<processors::RunLlamaCppInference>(); + controller.getProcessor()->setProperty(processors::RunLlamaCppInference::ModelPath.name, "/path/to/model"); + controller.getProcessor()->setProperty(processors::RunLlamaCppInference::Prompt.name, "Question: What is the answer to life, the universe and everything?"); + controller.getProcessor()->setProperty(processors::RunLlamaCppInference::Temperature.name, "0.4"); + controller.getProcessor()->setProperty(processors::RunLlamaCppInference::TopK.name, "20"); + controller.getProcessor()->setProperty(processors::RunLlamaCppInference::TopP.name, ""); + controller.getProcessor()->setProperty(processors::RunLlamaCppInference::MinP.name, "0.1"); + controller.getProcessor()->setProperty(processors::RunLlamaCppInference::MinKeep.name, "1"); + controller.getProcessor()->setProperty(processors::RunLlamaCppInference::TextContextSize.name, "4096"); + controller.getProcessor()->setProperty(processors::RunLlamaCppInference::LogicalMaximumBatchSize.name, "1024"); + controller.getProcessor()->setProperty(processors::RunLlamaCppInference::PhysicalMaximumBatchSize.name, "796"); + controller.getProcessor()->setProperty(processors::RunLlamaCppInference::MaxNumberOfSequences.name, "2"); + controller.getProcessor()->setProperty(processors::RunLlamaCppInference::ThreadsForGeneration.name, "12"); + controller.getProcessor()->setProperty(processors::RunLlamaCppInference::ThreadsForBatchProcessing.name, "8"); + controller.getProcessor()->setProperty(processors::RunLlamaCppInference::SystemPrompt.name, "Whatever"); + + auto results = 
controller.trigger(minifi::test::InputFlowFileData{.content = "42", .attributes = {}}); + CHECK(test_model_path == "/path/to/model"); + CHECK(test_sampler_params.temperature == 0.4F); + CHECK(test_sampler_params.top_k == 20); + CHECK(test_sampler_params.top_p == std::nullopt); + CHECK(test_sampler_params.min_p == 0.1F); + CHECK(test_sampler_params.min_keep == 1); + CHECK(test_context_params.n_ctx == 4096); + CHECK(test_context_params.n_batch == 1024); + CHECK(test_context_params.n_ubatch == 796); + CHECK(test_context_params.n_seq_max == 2); + CHECK(test_context_params.n_threads == 12); + CHECK(test_context_params.n_threads_batch == 8); + + REQUIRE(results.at(processors::RunLlamaCppInference::Success).size() == 1); + auto& output_flow_file = results.at(processors::RunLlamaCppInference::Success)[0]; + CHECK(controller.plan->getContent(output_flow_file) == "Test generated content"); + CHECK(mock_llama_context_ptr->getInput() == "Test input"); + REQUIRE(mock_llama_context_ptr->getMessages().size() == 2); + CHECK(mock_llama_context_ptr->getMessages()[0].role == "system"); + CHECK(mock_llama_context_ptr->getMessages()[0].content == "Whatever"); + CHECK(mock_llama_context_ptr->getMessages()[1].role == "user"); + CHECK(mock_llama_context_ptr->getMessages()[1].content == "Input data (or flow file content):\n42\n\nQuestion: What is the answer to life, the universe and everything?"); +} + +TEST_CASE("Empty flow file does not include input data in prompt") { + auto mock_llama_context = std::make_unique<MockLlamaContext>(); + auto mock_llama_context_ptr = mock_llama_context.get(); + processors::LlamaContext::testSetProvider( + [&](const std::filesystem::path&, const processors::LlamaSamplerParams&, const processors::LlamaContextParams&) { + return std::move(mock_llama_context); + }); + minifi::test::SingleProcessorTestController controller(std::make_unique<processors::RunLlamaCppInference>("RunLlamaCppInference")); + LogTestController::getInstance().setTrace<processors::RunLlamaCppInference>(); + controller.getProcessor()->setProperty(processors::RunLlamaCppInference::ModelPath.name, "Dummy model"); + controller.getProcessor()->setProperty(processors::RunLlamaCppInference::Prompt.name, "Question: What is the answer to life, the universe and everything?"); + + auto results = controller.trigger(minifi::test::InputFlowFileData{.content = "", .attributes = {}}); + + REQUIRE(results.at(processors::RunLlamaCppInference::Success).size() == 1); + auto& output_flow_file = results.at(processors::RunLlamaCppInference::Success)[0]; + CHECK(controller.plan->getContent(output_flow_file) == "Test generated content"); + CHECK(mock_llama_context_ptr->getInput() == "Test input"); + REQUIRE(mock_llama_context_ptr->getMessages().size() == 2); + CHECK(mock_llama_context_ptr->getMessages()[0].role == "system"); + CHECK(mock_llama_context_ptr->getMessages()[0].content == "You are a helpful assistant. You are given a question with some possible input data otherwise called flow file content. 
" + "You are expected to generate a response based on the question and the input data."); + CHECK(mock_llama_context_ptr->getMessages()[1].role == "user"); + CHECK(mock_llama_context_ptr->getMessages()[1].content == "Question: What is the answer to life, the universe and everything?"); +} + +TEST_CASE("Invalid values for optional double type properties throw exception") { + processors::LlamaContext::testSetProvider( + [&](const std::filesystem::path&, const processors::LlamaSamplerParams&, const processors::LlamaContextParams&) { + return std::make_unique<MockLlamaContext>(); + }); + minifi::test::SingleProcessorTestController controller(std::make_unique<processors::RunLlamaCppInference>("RunLlamaCppInference")); + LogTestController::getInstance().setTrace<processors::RunLlamaCppInference>(); + controller.getProcessor()->setProperty(processors::RunLlamaCppInference::ModelPath.name, "Dummy model"); + controller.getProcessor()->setProperty(processors::RunLlamaCppInference::Prompt.name, "Question: What is the answer to life, the universe and everything?"); + + std::string property_name; + SECTION("Invalid value for Temperature property") { + controller.getProcessor()->setProperty(processors::RunLlamaCppInference::Temperature.name, "invalid_value"); + property_name = processors::RunLlamaCppInference::Temperature.name; + } + SECTION("Invalid value for Top P property") { + controller.getProcessor()->setProperty(processors::RunLlamaCppInference::TopP.name, "invalid_value"); + property_name = processors::RunLlamaCppInference::TopP.name; + } + SECTION("Invalid value for Min P property") { + controller.getProcessor()->setProperty(processors::RunLlamaCppInference::MinP.name, "invalid_value"); + property_name = processors::RunLlamaCppInference::MinP.name; + } + + REQUIRE_THROWS_WITH(controller.trigger(minifi::test::InputFlowFileData{.content = "42", .attributes = {}}), + fmt::format("Expected parsable float from RunLlamaCppInference::{}: parsing error: GeneralParsingError (0)", property_name)); +} + +TEST_CASE("Top K property empty and invalid values are handled properly") { + std::optional<int32_t> test_top_k = 0; + processors::LlamaContext::testSetProvider( + [&](const std::filesystem::path&, const processors::LlamaSamplerParams& sampler_params, const processors::LlamaContextParams&) { + test_top_k = sampler_params.top_k; + return std::make_unique<MockLlamaContext>(); + }); + minifi::test::SingleProcessorTestController controller(std::make_unique<processors::RunLlamaCppInference>("RunLlamaCppInference")); + LogTestController::getInstance().setTrace<processors::RunLlamaCppInference>(); + controller.getProcessor()->setProperty(processors::RunLlamaCppInference::ModelPath.name, "Dummy model"); + controller.getProcessor()->setProperty(processors::RunLlamaCppInference::Prompt.name, "Question: What is the answer to life, the universe and everything?"); + SECTION("Empty value for Top K property") { + controller.getProcessor()->setProperty(processors::RunLlamaCppInference::TopK.name, ""); + auto results = controller.trigger(minifi::test::InputFlowFileData{.content = "42", .attributes = {}}); + REQUIRE(test_top_k == std::nullopt); + } + SECTION("Invalid value for Top K property") { + controller.getProcessor()->setProperty(processors::RunLlamaCppInference::TopK.name, "invalid_value"); + REQUIRE_THROWS_WITH(controller.trigger(minifi::test::InputFlowFileData{.content = "42", .attributes = {}}), + "Expected parsable int64_t from RunLlamaCppInference::Top K: parsing error: GeneralParsingError (0)"); + } +} + 
+TEST_CASE("Error handling during generation and applying template") { + auto mock_llama_context = std::make_unique<MockLlamaContext>(); + + SECTION("Generation fails with error") { + mock_llama_context->setGenerationFailure(); + } + + SECTION("Applying template fails with error") { + mock_llama_context->setApplyTemplateFailure(); + } + + processors::LlamaContext::testSetProvider( + [&](const std::filesystem::path&, const processors::LlamaSamplerParams&, const processors::LlamaContextParams&) { + return std::move(mock_llama_context); + }); + minifi::test::SingleProcessorTestController controller(std::make_unique<processors::RunLlamaCppInference>("RunLlamaCppInference")); + LogTestController::getInstance().setTrace<processors::RunLlamaCppInference>(); + controller.getProcessor()->setProperty(processors::RunLlamaCppInference::ModelPath.name, "/path/to/model"); + controller.getProcessor()->setProperty(processors::RunLlamaCppInference::Prompt.name, "Question: What is the answer to life, the universe and everything?"); + + auto results = controller.trigger(minifi::test::InputFlowFileData{.content = "42", .attributes = {}}); + + REQUIRE(results.at(processors::RunLlamaCppInference::Success).empty()); + REQUIRE(results.at(processors::RunLlamaCppInference::Failure).size() == 1); + auto& output_flow_file = results.at(processors::RunLlamaCppInference::Failure)[0]; + CHECK(controller.plan->getContent(output_flow_file) == "42"); +} + +TEST_CASE("Route flow file to failure when prompt and input data is empty") { + processors::LlamaContext::testSetProvider( + [&](const std::filesystem::path&, const processors::LlamaSamplerParams&, const processors::LlamaContextParams&) { + return std::make_unique<MockLlamaContext>(); + }); + minifi::test::SingleProcessorTestController controller(std::make_unique<processors::RunLlamaCppInference>("RunLlamaCppInference")); + LogTestController::getInstance().setTrace<processors::RunLlamaCppInference>(); + controller.getProcessor()->setProperty(processors::RunLlamaCppInference::ModelPath.name, "/path/to/model"); + controller.getProcessor()->setProperty(processors::RunLlamaCppInference::Prompt.name, ""); + + auto results = controller.trigger(minifi::test::InputFlowFileData{.content = "", .attributes = {}}); + + REQUIRE(results.at(processors::RunLlamaCppInference::Success).empty()); + REQUIRE(results.at(processors::RunLlamaCppInference::Failure).size() == 1); + auto& output_flow_file = results.at(processors::RunLlamaCppInference::Failure)[0]; + CHECK(controller.plan->getContent(output_flow_file).empty()); +} + +TEST_CASE("System prompt is optional") { + auto mock_llama_context = std::make_unique<MockLlamaContext>(); + auto mock_llama_context_ptr = mock_llama_context.get(); + processors::LlamaContext::testSetProvider( + [&](const std::filesystem::path&, const processors::LlamaSamplerParams&, const processors::LlamaContextParams&) { + return std::move(mock_llama_context); + }); + minifi::test::SingleProcessorTestController controller(std::make_unique<processors::RunLlamaCppInference>("RunLlamaCppInference")); + LogTestController::getInstance().setTrace<processors::RunLlamaCppInference>(); + controller.getProcessor()->setProperty(processors::RunLlamaCppInference::ModelPath.name, "Dummy model"); + controller.getProcessor()->setProperty(processors::RunLlamaCppInference::Prompt.name, "Question: What is the answer to life, the universe and everything?"); + controller.getProcessor()->setProperty(processors::RunLlamaCppInference::SystemPrompt.name, ""); + + auto results = 
controller.trigger(minifi::test::InputFlowFileData{.content = "42", .attributes = {}}); + + REQUIRE(results.at(processors::RunLlamaCppInference::Success).size() == 1); + auto& output_flow_file = results.at(processors::RunLlamaCppInference::Success)[0]; + CHECK(controller.plan->getContent(output_flow_file) == "Test generated content"); + CHECK(mock_llama_context_ptr->getInput() == "Test input"); + REQUIRE(mock_llama_context_ptr->getMessages().size() == 1); + CHECK(mock_llama_context_ptr->getMessages()[0].role == "user"); + CHECK(mock_llama_context_ptr->getMessages()[0].content == "Input data (or flow file content):\n42\n\nQuestion: What is the answer to life, the universe and everything?"); +} + +TEST_CASE("Test output metrics") { + processors::LlamaContext::testSetProvider( + [&](const std::filesystem::path&, const processors::LlamaSamplerParams&, const processors::LlamaContextParams&) { + return std::make_unique<MockLlamaContext>(); + }); + auto processor = std::make_unique<processors::RunLlamaCppInference>("RunLlamaCppInference"); + auto processor_metrics = processor->getMetrics(); + minifi::test::SingleProcessorTestController controller(std::move(processor)); + LogTestController::getInstance().setTrace<processors::RunLlamaCppInference>(); + controller.getProcessor()->setProperty(processors::RunLlamaCppInference::ModelPath.name, "Dummy model"); + controller.getProcessor()->setProperty(processors::RunLlamaCppInference::Prompt.name, "Question: What is the answer to life, the universe and everything?"); + + controller.trigger(minifi::test::InputFlowFileData{.content = "42", .attributes = {}}); + auto results = controller.trigger(minifi::test::InputFlowFileData{.content = "42", .attributes = {}}); + + REQUIRE(results.at(processors::RunLlamaCppInference::Success).size() == 1); + auto prometheus_metrics = processor_metrics->calculateMetrics(); + CHECK(prometheus_metrics[prometheus_metrics.size() - 2].name == "tokens_in"); + CHECK(prometheus_metrics[prometheus_metrics.size() - 2].value == 20); + CHECK(prometheus_metrics[prometheus_metrics.size() - 1].name == "tokens_out"); + CHECK(prometheus_metrics[prometheus_metrics.size() - 1].value == 6); + auto c2_metrics = processor_metrics->serialize(); + CHECK(c2_metrics[0].children[c2_metrics[0].children.size() - 2].name == "TokensIn"); Review Comment: Updated in https://github.com/apache/nifi-minifi-cpp/pull/1903/commits/e93db2c01d6e2e159b287add948f5386319b21b6 -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: issues-unsubscr...@nifi.apache.org For queries about this service, please contact Infrastructure at: us...@infra.apache.org
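On the two rate-related comments above: the quoted diff computes result.tokens_per_second by dividing the generated token count by the elapsed milliseconds and then dividing by 1000.0 again, while tokens per second is tokens * 1000 / milliseconds, and the attribute is written with std::to_string, which always prints six decimals. A minimal sketch of that ratio plus a fixed two-decimal attribute string, assuming fmt is available as it is elsewhere in the processor sources; the helper name is hypothetical and this is not necessarily what the follow-up commit does:

    #include <chrono>
    #include <cstdint>
    #include <string>

    #include "fmt/format.h"

    // Hypothetical helper (not part of the PR): tokens per second from a millisecond
    // duration, rendered with fixed two-decimal precision for the flow file attribute.
    std::string tokensPerSecondAttribute(uint64_t num_tokens_out, std::chrono::milliseconds elapsed) {
      // tokens / seconds == tokens * 1000 / milliseconds; guard against a zero-length duration
      const double seconds = std::chrono::duration<double>(elapsed).count();
      const double tokens_per_second = seconds > 0.0 ? static_cast<double>(num_tokens_out) / seconds : 0.0;
      return fmt::format("{:.2f}", tokens_per_second);  // e.g. 3 tokens over 1500 ms -> "2.00"
    }

Doing the division in seconds avoids the double scaling in the quoted expression, and fmt's "{:.2f}" gives the fixed two-decimal output the review asked for instead of std::to_string's "2.000000".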