Skip to content

Commit

Permalink
add EmbeddingScheduler
Browse files Browse the repository at this point in the history
  • Loading branch information
Sidsector9 committed Jul 5, 2024
1 parent b7d6d21 commit 9a3645c
Show file tree
Hide file tree
Showing 4 changed files with 256 additions and 99 deletions.
126 changes: 126 additions & 0 deletions includes/Classifai/EmbeddingsScheduler.php
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
<?php

namespace Classifai;

use ActionScheduler_Action;
use ActionScheduler_DBLogger;
use ActionScheduler_Store;

class EmbeddingsScheduler {
/**
* The name of the job.
*
* @var string
*/
private $job_name = '';

/**
* The name of the provider.
*
* @var string
*/
private $provider_name = '';

/**
* EmbeddingsScheduler constructor.
*
* @param string $job_name The name of the job.
*/
public function __construct( $job_name = '', $provider_name = '' ) {
$this->job_name = $job_name;
$this->provider_name = $provider_name;
}

/**
* Initialize the class.
*/
public function init() {
add_filter( 'heartbeat_send', [ $this, 'check_embedding_generation_status' ] );
add_action( 'classifai_before_feature_nav', [ $this, 'render_embeddings_generation_status' ] );
add_action( 'action_scheduler_after_execute', [ $this, 'log_failed_embeddings' ], 10, 2 );
}

/**
* Check if embeddings generation is in progress.
*
* @return bool
*/
public function is_embeddings_generation_in_progress(): bool {
if ( ! class_exists( 'ActionScheduler_Store' ) ) {
return false;
}

$store = ActionScheduler_Store::instance();

$action_id = $store->find_action(
$this->job_name,
array(
'status' => ActionScheduler_Store::STATUS_PENDING,
)
);

return ! empty( $action_id );
}

/**
* Render the embeddings generation status notice.
*/
public function render_embeddings_generation_status() {
if ( ! $this->is_embeddings_generation_in_progress() ) {
return;
}

?>
<div class="notice notice-info classifai-classification-embeddings-message">
<p>
<?php
printf(
'<strong>%1$s</strong>: %2$s',
$this->provider_name,

Check failure on line 79 in includes/Classifai/EmbeddingsScheduler.php

View workflow job for this annotation

GitHub Actions / vipcs

All output should be run through an escaping function (see the Security sections in the WordPress Developer Handbooks), found '$this'.
esc_html__( 'Generation of embeddings is in progress.', 'classifai' )
)
?>
</p>
</div>
<?php
}

/**
* AJAX callback to check the status of embeddings generation.
*
* @param array $response The heartbeat response.
* @return array
*/
public function check_embedding_generation_status( $response ) {
$response['classifaiEmbedInProgress'] = $this->is_embeddings_generation_in_progress();

return $response;
}

/**
* Logs failed embeddings.
*
* @param int $action_id The action ID.
* @param ActionScheduler_Action $action The action object.
*/
public function log_failed_embeddings( $action_id, $action ) {
if ( $this->job_name !== $action->get_hook() ) {
return;
}

$args = $action->get_args();

if ( ! isset( $args['args'] ) && ! isset( $args['args']['exclude'] ) ) {
return;
}

$excludes = $args['args']['exclude'];

if ( empty( $excludes ) || ( 1 === count( $excludes ) && in_array( 1, $excludes, true ) ) ) {
return;
}

$logger = new ActionScheduler_DBLogger();
$logger->log( $action_id, sprintf( 'Embeddings failed for terms: %s', implode( ', ', $excludes ) ) );
}
}
113 changes: 106 additions & 7 deletions includes/Classifai/Providers/Azure/Embeddings.php
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
use Classifai\Normalizer;
use Classifai\Features\Classification;
use Classifai\Features\Feature;
use Classifai\EmbeddingsScheduler;
use WP_Error;

class Embeddings extends OpenAI {
Expand Down Expand Up @@ -57,6 +58,13 @@ class Embeddings extends OpenAI {
*/
public $nlu_features = [];

/**
* Scheduler instance.
*
* @var EmbeddingsScheduler|null
*/
private static $scheduler_instance = null;

/**
* OpenAI Embeddings constructor.
*
Expand Down Expand Up @@ -162,6 +170,13 @@ public function register() {

$feature = new Classification();

self::$scheduler_instance = new EmbeddingsScheduler(
'classifai_schedule_generate_azure_embedding_job',
__( 'Azure OpenAI Embeddings', 'classifai' )
);
self::$scheduler_instance->init();
add_action( 'classifai_schedule_generate_azure_embedding_job', [ $this, 'generate_embedding_job' ], 10, 4 );

if (
! $feature->is_feature_enabled() ||
$feature->get_feature_provider_instance()::ID !== static::ID
Expand Down Expand Up @@ -670,12 +685,23 @@ private function get_embeddings_similarity( array $embedding, bool $consider_thr
}

/**
* Generate embedding data for all terms within a taxonomy.
* Schedules the job to generate embedding data for all terms within a taxonomy.
*
* @param string $taxonomy Taxonomy slug.
* @param bool $all Whether to generate embeddings for all terms or just those without embeddings.
* @param array $args Overrideable query args for get_terms()
* @param int $user_id The user ID to run this as.
*/
private function trigger_taxonomy_update( string $taxonomy = '', bool $all = false ) {
private function trigger_taxonomy_update( string $taxonomy = '', bool $all = false, array $args = [], int $user_id = 0 ) {
$feature = new Classification();

if (
! $feature->is_feature_enabled() ||
$feature->get_feature_provider_instance()::ID !== static::ID
) {
return;
}

$exclude = [];

// Exclude the uncategorized term.
Expand All @@ -686,34 +712,107 @@ private function trigger_taxonomy_update( string $taxonomy = '', bool $all = fal
}
}

$args = [
/**
* Filter the number of terms to process in a batch.
*
* @since 3.1.0
* @hook classifai_azure_openai_embeddings_terms_per_job
*
* @param {int} $number Number of terms to process per job.
*
* @return {int} Filtered number of terms to process per job.
*/
$number = apply_filters( 'classifai_azure_openai_embeddings_terms_per_job', 100 );

$default_args = [
'taxonomy' => $taxonomy,
'orderby' => 'count',
'order' => 'DESC',
'hide_empty' => false,
'fields' => 'ids',
'meta_key' => 'classifai_azure_openai_embeddings', // phpcs:ignore WordPress.DB.SlowDBQuery.slow_db_query_meta_key
'meta_compare' => 'NOT EXISTS',
'number' => $this->get_max_terms(),
'number' => $number,
'offset' => 0,
'exclude' => $exclude, // phpcs:ignore WordPressVIPMinimum.Performance.WPQueryParams.PostNotIn_exclude
];

$default_args = array_merge( $default_args, $args );

// If we want all terms, remove our meta query.
if ( $all ) {
unset( $args['meta_key'], $args['meta_compare'] );
unset( $default_args['meta_key'], $default_args['meta_compare'] );
} else {
unset( $default_args['offset'] );
}

$terms = get_terms( $args );
if ( 0 === $user_id ) {
$user_id = get_current_user_id();
}

$job_args = [
'taxonomy' => $taxonomy,
'all' => $all,
'args' => $default_args,
'user_id' => $user_id,
];

// We return early and don't schedule the job if there are no terms.
if ( ! as_has_scheduled_action( 'classifai_schedule_generate_azure_embedding_job', $job_args ) ) {
$terms = get_terms( $default_args );

if ( is_wp_error( $terms ) || empty( $terms ) ) {
return;
}
}

\as_enqueue_async_action( 'classifai_schedule_generate_azure_embedding_job', $job_args );
}

/**
* Job to generate embedding data for all terms within a taxonomy.
*
* @param string $taxonomy Taxonomy slug.
* @param bool $all Whether to generate embeddings for all terms or just those without embeddings.
* @param array $args Overrideable query args for get_terms()
* @param int $user_id The user ID to run this as.
*/
public function generate_embedding_job( string $taxonomy = '', bool $all = false, array $args = [], int $user_id = 0 ) {

if ( $user_id > 0 ) {
// We set this as current_user_can() fails when this function runs
// under the context of Action Scheduler.
wp_set_current_user( $user_id );
}

$terms = get_terms( $args );
if ( is_wp_error( $terms ) || empty( $terms ) ) {
return;
}

// Re-orders the keys.
$terms = array_values( $terms );
$exclude = [];

// Generate embedding data for each term.
foreach ( $terms as $term_id ) {
/** @var int $term_id */
$this->generate_embeddings_for_term( $term_id, $all );
$has_generated = $this->generate_embeddings_for_term( $term_id, $all );

if ( is_wp_error( $has_generated ) ) {
$exclude[] = $term_id;
}
}

if ( $all && isset( $args['offset'] ) && isset( $args['number'] ) ) {
$args['offset'] = $args['offset'] + $args['number'];
}

if ( ! empty( $exclude ) ) {
$args['exclude'] = array_merge( $args['exclude'], $exclude ); // phpcs:ignore WordPressVIPMinimum.Performance.WPQueryParams.PostNotIn_exclude
}

$this->trigger_taxonomy_update( $taxonomy, $all, $args, $user_id );
}

/**
Expand Down
Loading

0 comments on commit 9a3645c

Please sign in to comment.