Skip to content

Commit

Permalink
[Enterprise Search] Add model management API logic (elastic#172120)
Browse files Browse the repository at this point in the history
## Summary

Adding parts of ML model management API logic:
- Fetch models
- Cached and pollable wrapper for model fetching
- Create model
- Start model

These API logic pieces map to existing API endpoints and are currently
unused. Their purpose is to enable one-click deployment of models within
pipeline configuration.

---------

Co-authored-by: kibanamachine <[email protected]>
  • Loading branch information
demjened and kibanamachine authored Nov 30, 2023
1 parent 508e9da commit 4ab4239
Show file tree
Hide file tree
Showing 8 changed files with 525 additions and 0 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,229 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

import { LogicMounter } from '../../../__mocks__/kea_logic';

import { HttpError, Status } from '../../../../../common/types/api';
import { MlModelDeploymentState } from '../../../../../common/types/ml';

import { MlModel } from '../../../../../common/types/ml';

import {
CachedFetchModelsApiLogic,
CachedFetchModelsApiLogicValues,
} from './cached_fetch_models_api_logic';
import { FetchModelsApiLogic } from './fetch_models_api_logic';

const DEFAULT_VALUES: CachedFetchModelsApiLogicValues = {
data: [],
isInitialLoading: false,
isLoading: false,
modelsData: null,
pollTimeoutId: null,
status: Status.IDLE,
};

const FETCH_MODELS_API_DATA_RESPONSE: MlModel[] = [
{
modelId: 'model_1',
title: 'Model 1',
type: 'ner',
deploymentState: MlModelDeploymentState.NotDeployed,
startTime: 0,
targetAllocationCount: 0,
nodeAllocationCount: 0,
threadsPerAllocation: 0,
isPlaceholder: false,
hasStats: false,
},
];
const FETCH_MODELS_API_ERROR_RESPONSE = {
body: {
error: 'Error while fetching models',
message: 'Error while fetching models',
statusCode: 500,
},
} as HttpError;

jest.useFakeTimers();

describe('TextExpansionCalloutLogic', () => {
const { mount } = new LogicMounter(CachedFetchModelsApiLogic);
const { mount: mountFetchModelsApiLogic } = new LogicMounter(FetchModelsApiLogic);

beforeEach(() => {
jest.clearAllMocks();
mountFetchModelsApiLogic();
mount();
});

describe('listeners', () => {
describe('apiError', () => {
it('sets new polling timeout if a timeout ID is already set', () => {
mount({
...DEFAULT_VALUES,
pollTimeoutId: 'timeout-id',
});

jest.spyOn(CachedFetchModelsApiLogic.actions, 'createPollTimeout');

CachedFetchModelsApiLogic.actions.apiError(FETCH_MODELS_API_ERROR_RESPONSE);

expect(CachedFetchModelsApiLogic.actions.createPollTimeout).toHaveBeenCalled();
});
});

describe('apiSuccess', () => {
it('sets new polling timeout if a timeout ID is already set', () => {
mount({
...DEFAULT_VALUES,
pollTimeoutId: 'timeout-id',
});

jest.spyOn(CachedFetchModelsApiLogic.actions, 'createPollTimeout');

CachedFetchModelsApiLogic.actions.apiSuccess(FETCH_MODELS_API_DATA_RESPONSE);

expect(CachedFetchModelsApiLogic.actions.createPollTimeout).toHaveBeenCalled();
});
});

describe('createPollTimeout', () => {
const duration = 5000;
it('clears polling timeout if it is set', () => {
mount({
...DEFAULT_VALUES,
pollTimeoutId: 'timeout-id',
});

jest.spyOn(global, 'clearTimeout');

CachedFetchModelsApiLogic.actions.createPollTimeout(duration);

expect(clearTimeout).toHaveBeenCalledWith('timeout-id');
});
it('sets polling timeout', () => {
jest.spyOn(global, 'setTimeout');
jest.spyOn(CachedFetchModelsApiLogic.actions, 'setTimeoutId');

CachedFetchModelsApiLogic.actions.createPollTimeout(duration);

expect(setTimeout).toHaveBeenCalledWith(expect.any(Function), duration);
expect(CachedFetchModelsApiLogic.actions.setTimeoutId).toHaveBeenCalled();
});
});

describe('startPolling', () => {
it('clears polling timeout if it is set', () => {
mount({
...DEFAULT_VALUES,
pollTimeoutId: 'timeout-id',
});

jest.spyOn(global, 'clearTimeout');

CachedFetchModelsApiLogic.actions.startPolling();

expect(clearTimeout).toHaveBeenCalledWith('timeout-id');
});
it('makes API request and sets polling timeout', () => {
jest.spyOn(CachedFetchModelsApiLogic.actions, 'makeRequest');
jest.spyOn(CachedFetchModelsApiLogic.actions, 'createPollTimeout');

CachedFetchModelsApiLogic.actions.startPolling();

expect(CachedFetchModelsApiLogic.actions.makeRequest).toHaveBeenCalled();
expect(CachedFetchModelsApiLogic.actions.createPollTimeout).toHaveBeenCalled();
});
});

describe('stopPolling', () => {
it('clears polling timeout if it is set', () => {
mount({
...DEFAULT_VALUES,
pollTimeoutId: 'timeout-id',
});

jest.spyOn(global, 'clearTimeout');

CachedFetchModelsApiLogic.actions.stopPolling();

expect(clearTimeout).toHaveBeenCalledWith('timeout-id');
});
it('clears polling timeout value', () => {
jest.spyOn(CachedFetchModelsApiLogic.actions, 'clearPollTimeout');

CachedFetchModelsApiLogic.actions.stopPolling();

expect(CachedFetchModelsApiLogic.actions.clearPollTimeout).toHaveBeenCalled();
});
});
});

describe('reducers', () => {
describe('modelsData', () => {
it('gets cleared on API reset', () => {
mount({
...DEFAULT_VALUES,
modelsData: [],
});

CachedFetchModelsApiLogic.actions.apiReset();

expect(CachedFetchModelsApiLogic.values.modelsData).toBe(null);
});
it('gets set on API success', () => {
CachedFetchModelsApiLogic.actions.apiSuccess(FETCH_MODELS_API_DATA_RESPONSE);

expect(CachedFetchModelsApiLogic.values.modelsData).toEqual(FETCH_MODELS_API_DATA_RESPONSE);
});
});

describe('pollTimeoutId', () => {
it('gets cleared on clear timeout action', () => {
mount({
...DEFAULT_VALUES,
pollTimeoutId: 'timeout-id',
});

CachedFetchModelsApiLogic.actions.clearPollTimeout();

expect(CachedFetchModelsApiLogic.values.pollTimeoutId).toBe(null);
});
it('gets set on set timeout action', () => {
const timeout = setTimeout(() => {}, 500);

CachedFetchModelsApiLogic.actions.setTimeoutId(timeout);

expect(CachedFetchModelsApiLogic.values.pollTimeoutId).toEqual(timeout);
});
});
});

describe('selectors', () => {
describe('isInitialLoading', () => {
it('true if API is idle', () => {
mount(DEFAULT_VALUES);

expect(CachedFetchModelsApiLogic.values.isInitialLoading).toBe(true);
});
it('true if API is loading for the first time', () => {
mount({
...DEFAULT_VALUES,
status: Status.LOADING,
});

expect(CachedFetchModelsApiLogic.values.isInitialLoading).toBe(true);
});
it('false if the API is neither idle nor loading', () => {
CachedFetchModelsApiLogic.actions.apiSuccess(FETCH_MODELS_API_DATA_RESPONSE);

expect(CachedFetchModelsApiLogic.values.isInitialLoading).toBe(false);
});
});
});
});
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

import { kea, MakeLogicType } from 'kea';

import { isEqual } from 'lodash';

import { Status } from '../../../../../common/types/api';
import { MlModel } from '../../../../../common/types/ml';
import { Actions } from '../../../shared/api_logic/create_api_logic';

import { FetchModelsApiLogic, FetchModelsApiResponse } from './fetch_models_api_logic';

const FETCH_MODELS_POLLING_DURATION = 5000; // 5 seconds
const FETCH_MODELS_POLLING_DURATION_ON_FAILURE = 30000; // 30 seconds

export interface CachedFetchModlesApiLogicActions {
apiError: Actions<{}, FetchModelsApiResponse>['apiError'];
apiReset: Actions<{}, FetchModelsApiResponse>['apiReset'];
apiSuccess: Actions<{}, FetchModelsApiResponse>['apiSuccess'];
clearPollTimeout(): void;
createPollTimeout(duration: number): { duration: number };
makeRequest: Actions<{}, FetchModelsApiResponse>['makeRequest'];
setTimeoutId(id: NodeJS.Timeout): { id: NodeJS.Timeout };
startPolling(): void;
stopPolling(): void;
}

export interface CachedFetchModelsApiLogicValues {
data: FetchModelsApiResponse;
isInitialLoading: boolean;
isLoading: boolean;
modelsData: MlModel[] | null;
pollTimeoutId: NodeJS.Timeout | null;
status: Status;
}

export const CachedFetchModelsApiLogic = kea<
MakeLogicType<CachedFetchModelsApiLogicValues, CachedFetchModlesApiLogicActions>
>({
actions: {
clearPollTimeout: true,
createPollTimeout: (duration) => ({ duration }),
setTimeoutId: (id) => ({ id }),
startPolling: true,
stopPolling: true,
},
connect: {
actions: [FetchModelsApiLogic, ['apiSuccess', 'apiError', 'apiReset', 'makeRequest']],
values: [FetchModelsApiLogic, ['data', 'status']],
},
events: ({ values }) => ({
beforeUnmount: () => {
if (values.pollTimeoutId) {
clearTimeout(values.pollTimeoutId);
}
},
}),
listeners: ({ actions, values }) => ({
apiError: () => {
if (values.pollTimeoutId) {
actions.createPollTimeout(FETCH_MODELS_POLLING_DURATION_ON_FAILURE);
}
},
apiSuccess: () => {
if (values.pollTimeoutId) {
actions.createPollTimeout(FETCH_MODELS_POLLING_DURATION);
}
},
createPollTimeout: ({ duration }) => {
if (values.pollTimeoutId) {
clearTimeout(values.pollTimeoutId);
}

const timeoutId = setTimeout(() => {
actions.makeRequest({});
}, duration);
actions.setTimeoutId(timeoutId);
},
startPolling: () => {
if (values.pollTimeoutId) {
clearTimeout(values.pollTimeoutId);
}
actions.makeRequest({});
actions.createPollTimeout(FETCH_MODELS_POLLING_DURATION);
},
stopPolling: () => {
if (values.pollTimeoutId) {
clearTimeout(values.pollTimeoutId);
}
actions.clearPollTimeout();
},
}),
path: ['enterprise_search', 'content', 'api', 'fetch_models_api_wrapper'],
reducers: {
modelsData: [
null,
{
apiReset: () => null,
apiSuccess: (currentState, newState) =>
isEqual(currentState, newState) ? currentState : newState,
},
],
pollTimeoutId: [
null,
{
clearPollTimeout: () => null,
setTimeoutId: (_, { id }) => id,
},
],
},
selectors: ({ selectors }) => ({
isInitialLoading: [
() => [selectors.status, selectors.modelsData],
(
status: CachedFetchModelsApiLogicValues['status'],
modelsData: CachedFetchModelsApiLogicValues['modelsData']
) => status === Status.IDLE || (modelsData === null && status === Status.LOADING),
],
}),
});
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

import { mockHttpValues } from '../../../__mocks__/kea_logic';

import { nextTick } from '@kbn/test-jest-helpers';

import { createModel } from './create_model_api_logic';

describe('CreateModelApiLogic', () => {
const { http } = mockHttpValues;
beforeEach(() => {
jest.clearAllMocks();
});
describe('createModel', () => {
it('calls correct api', async () => {
const mockResponseBody = { modelId: 'model_1', deploymentState: '' };
http.post.mockReturnValue(Promise.resolve(mockResponseBody));

const result = createModel({ modelId: 'model_1' });
await nextTick();
expect(http.post).toHaveBeenCalledWith('/internal/enterprise_search/ml/models/model_1');
await expect(result).resolves.toEqual(mockResponseBody);
});
});
});
Loading

0 comments on commit 4ab4239

Please sign in to comment.