Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

@jakmro/standardize naming #62

Merged
merged 5 commits into from
Dec 17, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ class Classification(reactContext: ReactApplicationContext) :
classificationModel.loadModel(modelSource)
promise.resolve(0)
} catch (e: Exception) {
promise.reject(e.message!!, ETError.InvalidModelPath.toString())
promise.reject(e.message!!, ETError.InvalidModelSource.toString())
}
}

Expand Down
6 changes: 3 additions & 3 deletions android/src/main/java/com/swmansion/rnexecutorch/ETModule.kt
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,13 @@ class ETModule(reactContext: ReactApplicationContext) : NativeETModuleSpec(react
return NAME
}

override fun loadModule(modelPath: String, promise: Promise) {
override fun loadModule(modelSource: String, promise: Promise) {
Fetcher.downloadModel(
reactApplicationContext,
modelPath,
modelSource,
) { path, error ->
if (error != null) {
promise.reject(error.message!!, ETError.InvalidModelPath.toString())
promise.reject(error.message!!, ETError.InvalidModelSource.toString())
return@downloadModel
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ class StyleTransfer(reactContext: ReactApplicationContext) :
styleTransferModel.loadModel(modelSource)
promise.resolve(0)
} catch (e: Exception) {
promise.reject(e.message!!, ETError.InvalidModelPath.toString())
promise.reject(e.message!!, ETError.InvalidModelSource.toString())
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ enum class ETError(val code: Int) {
UndefinedError(0x65),
ModuleNotLoaded(0x66),
FileWriteFailed(0x67),
InvalidModelPath(0xff),
InvalidModelSource(0xff),

// System errors
Ok(0x00),
Expand Down
6 changes: 3 additions & 3 deletions docs/docs/guides/running-llms.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,10 @@ React Native ExecuTorch supports Llama 3.2 models, including quantized versions.
In order to load a model into the app, you need to run the following code:

```typescript
import { useLLM, LLAMA3_2_1B_URL } from 'react-native-executorch';
import { useLLM, LLAMA3_2_1B } from 'react-native-executorch';

const llama = useLLM({
modelSource: LLAMA3_2_1B_URL,
modelSource: LLAMA3_2_1B,
tokenizer: require('../assets/tokenizer.bin'),
contextWindowLength: 3,
});
Expand Down Expand Up @@ -91,7 +91,7 @@ In order to send a message to the model, one can use the following code:

```typescript
const llama = useLLM(
modelSource: LLAMA3_2_1B_URL,
modelSource: LLAMA3_2_1B,
tokenizer: require('../assets/tokenizer.bin'),
);

Expand Down
2 changes: 1 addition & 1 deletion examples/computer-vision/screens/ClassificationScreen.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ export const ClassificationScreen = ({
);

const model = useClassification({
modulePath: EFFICIENTNET_V2_S,
modelSource: EFFICIENTNET_V2_S,
});

const handleCameraPress = async (isCamera: boolean) => {
Expand Down
2 changes: 1 addition & 1 deletion examples/computer-vision/screens/StyleTransferScreen.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ export const StyleTransferScreen = ({
setImageUri: (imageUri: string) => void;
}) => {
const model = useStyleTransfer({
modulePath: STYLE_TRANSFER_CANDY,
modelSource: STYLE_TRANSFER_CANDY,
});

const handleCameraPress = async (isCamera: boolean) => {
Expand Down
4 changes: 2 additions & 2 deletions examples/llama/screens/ChatScreen.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ import { SafeAreaView } from 'react-native-safe-area-context';
import SWMIcon from '../assets/icons/swm_icon.svg';
import SendIcon from '../assets/icons/send_icon.svg';
import Spinner from 'react-native-loading-spinner-overlay';
import { LLAMA3_2_1B_QLORA_URL, useLLM } from 'react-native-executorch';
import { LLAMA3_2_1B_QLORA, useLLM } from 'react-native-executorch';
import PauseIcon from '../assets/icons/pause_icon.svg';
import ColorPalette from '../colors';
import Messages from '../components/Messages';
Expand All @@ -25,7 +25,7 @@ export default function ChatScreen() {
const [isTextInputFocused, setIsTextInputFocused] = useState(false);
const [userInput, setUserInput] = useState('');
const llama = useLLM({
modelSource: LLAMA3_2_1B_QLORA_URL,
modelSource: LLAMA3_2_1B_QLORA,
tokenizerSource: require('../assets/tokenizer.bin'),
contextWindowLength: 6,
});
Expand Down
2 changes: 1 addition & 1 deletion ios/RnExecutorch/models/BaseModel.mm
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ - (void)loadModel:(NSURL *)modelURL completion:(void (^)(BOOL success, NSNumber*
module = [[ETModel alloc] init];
[Fetcher fetchResource:modelURL resourceType:ResourceType::MODEL completionHandler:^(NSString *filePath, NSError *error) {
if (error) {
completion(NO, @(InvalidModelPath));
completion(NO, @(InvalidModelSource));
return;
}
NSNumber *result = [self->module loadModel: filePath];
Expand Down
2 changes: 1 addition & 1 deletion ios/RnExecutorch/utils/ETError.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ typedef NS_ENUM(NSUInteger, ETError) {
UndefinedError = 0x65,
ModuleNotLoaded = 0x66,
FileWriteFailed = 0x67,
InvalidModelPath = 0xff,
InvalidModelSource = 0xff,

Ok = 0x00,
Internal = 0x01,
Expand Down
12 changes: 6 additions & 6 deletions src/ETModule.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,21 +15,21 @@ const getTypeIdentifier = (arr: ETInput): number => {
};

interface Props {
modulePath: string | number;
modelSource: string | number;
}

export const useExecutorchModule = ({
modulePath,
modelSource,
}: Props): ExecutorchModule => {
const [error, setError] = useState<string | null>(null);
const [isModelLoading, setIsModelLoading] = useState(true);
const [isModelGenerating, setIsModelGenerating] = useState(false);

useEffect(() => {
const loadModel = async () => {
let path = modulePath;
if (typeof modulePath === 'number') {
path = Image.resolveAssetSource(modulePath).uri;
let path = modelSource;
if (typeof modelSource === 'number') {
path = Image.resolveAssetSource(modelSource).uri;
}

try {
Expand All @@ -42,7 +42,7 @@ export const useExecutorchModule = ({
}
};
loadModel();
}, [modulePath]);
}, [modelSource]);

const forward = async (input: ETInput, shape: number[]) => {
if (isModelLoading) {
Expand Down
2 changes: 1 addition & 1 deletion src/Error.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ export enum ETError {
ModuleNotLoaded = 0x66,
FileWriteFailed = 0x67,
ModelGenerating = 0x68,
InvalidModelPath = 0xff,
InvalidModelSource = 0xff,

// ExecuTorch mapped errors
// Based on: https://github.com/pytorch/executorch/blob/main/runtime/core/error.h
Expand Down
12 changes: 6 additions & 6 deletions src/StyleTransfer.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ import { StyleTransfer } from './native/RnExecutorchModules';
import { ETError, getError } from './Error';

interface Props {
modulePath: string | number;
modelSource: string | number;
}

interface StyleTransferModule {
Expand All @@ -15,18 +15,18 @@ interface StyleTransferModule {
}

export const useStyleTransfer = ({
modulePath,
modelSource,
}: Props): StyleTransferModule => {
const [error, setError] = useState<null | string>(null);
const [isModelReady, setIsModelReady] = useState(false);
const [isModelGenerating, setIsModelGenerating] = useState(false);

useEffect(() => {
const loadModel = async () => {
let path = modulePath;
let path = modelSource;

if (typeof modulePath === 'number') {
path = Image.resolveAssetSource(modulePath).uri;
if (typeof modelSource === 'number') {
path = Image.resolveAssetSource(modelSource).uri;
}

try {
Expand All @@ -39,7 +39,7 @@ export const useStyleTransfer = ({
};

loadModel();
}, [modulePath]);
}, [modelSource]);

const forward = async (input: string) => {
if (!isModelReady) {
Expand Down
34 changes: 20 additions & 14 deletions src/constants/modelUrls.ts
Original file line number Diff line number Diff line change
@@ -1,17 +1,17 @@
import { Platform } from 'react-native';

// LLM's
export const LLAMA3_2_3B_URL =
export const LLAMA3_2_3B =
'https://huggingface.co/software-mansion/react-native-executorch-llama-3.2/resolve/v0.1.0/llama-3.2-3B/original/llama3_2_3B_bf16.pte';
export const LLAMA3_2_3B_QLORA_URL =
export const LLAMA3_2_3B_QLORA =
'https://huggingface.co/software-mansion/react-native-executorch-llama-3.2/resolve/v0.1.0/llama-3.2-3B/QLoRA/llama3_2-3B_qat_lora.pte';
export const LLAMA3_2_3B_SPINQUANT_URL =
export const LLAMA3_2_3B_SPINQUANT =
'https://huggingface.co/software-mansion/react-native-executorch-llama-3.2/resolve/v0.1.0/llama-3.2-3B/spinquant/llama3_2_3B_spinquant.pte';
export const LLAMA3_2_1B_URL =
export const LLAMA3_2_1B =
'https://huggingface.co/software-mansion/react-native-executorch-llama-3.2/resolve/v0.1.0/llama-3.2-1B/original/llama3_2_bf16.pte';
export const LLAMA3_2_1B_QLORA_URL =
export const LLAMA3_2_1B_QLORA =
'https://huggingface.co/software-mansion/react-native-executorch-llama-3.2/resolve/v0.1.0/llama-3.2-1B/QLoRA/llama3_2_qat_lora.pte';
export const LLAMA3_2_1B_SPINQUANT_URL =
export const LLAMA3_2_1B_SPINQUANT =
'https://huggingface.co/software-mansion/react-native-executorch-llama-3.2/resolve/v0.1.0/llama-3.2-1B/spinquant/llama3_2_spinquant.pte';
export const LLAMA3_2_1B_TOKENIZER =
'https://huggingface.co/software-mansion/react-native-executorch-llama-3.2/resolve/v0.1.0/llama-3.2-1B/original/tokenizer.bin';
Expand All @@ -29,17 +29,23 @@ export const STYLE_TRANSFER_CANDY =
Platform.OS === 'ios'
? 'https://huggingface.co/software-mansion/react-native-executorch-style-transfer-candy/resolve/v0.2.0/coreml/style_transfer_candy_coreml.pte'
: 'https://huggingface.co/software-mansion/react-native-executorch-style-transfer-candy/resolve/v0.2.0/xnnpack/style_transfer_candy_xnnpack.pte';

export const STYLE_TRANSFER_MOSAIC =
Platform.OS === 'ios'
? 'https://huggingface.co/software-mansion/react-native-executorch-style-transfer-mosaic/resolve/main/coreml/style_transfer_mosaic_coreml.pte'
: 'https://huggingface.co/software-mansion/react-native-executorch-style-transfer-mosaic/resolve/main/xnnpack/style_transfer_mosaic_xnnpack.pte';

? 'https://huggingface.co/software-mansion/react-native-executorch-style-transfer-mosaic/resolve/v0.2.0/coreml/style_transfer_mosaic_coreml.pte'
: 'https://huggingface.co/software-mansion/react-native-executorch-style-transfer-mosaic/resolve/v0.2.0/xnnpack/style_transfer_mosaic_xnnpack.pte';
export const STYLE_TRANSFER_RAIN_PRINCESS =
Platform.OS === 'ios'
? 'https://huggingface.co/software-mansion/react-native-executorch-style-transfer-rain-princess/resolve/main/coreml/style_transfer_rain_princess_coreml.pte'
: 'https://huggingface.co/software-mansion/react-native-executorch-style-transfer-rain-princess/resolve/main/xnnpack/style_transfer_rain_princess_xnnpack.pte';
? 'https://huggingface.co/software-mansion/react-native-executorch-style-transfer-rain-princess/resolve/v0.2.0/coreml/style_transfer_rain_princess_coreml.pte'
: 'https://huggingface.co/software-mansion/react-native-executorch-style-transfer-rain-princess/resolve/v0.2.0/xnnpack/style_transfer_rain_princess_xnnpack.pte';
export const STYLE_TRANSFER_UDNIE =
Platform.OS === 'ios'
? 'https://huggingface.co/software-mansion/react-native-executorch-style-transfer-udnie/resolve/main/coreml/style_transfer_udnie_coreml.pte'
: 'https://huggingface.co/software-mansion/react-native-executorch-style-transfer-udnie/resolve/main/xnnpack/style_transfer_udnie_xnnpack.pte';
? 'https://huggingface.co/software-mansion/react-native-executorch-style-transfer-udnie/resolve/v0.2.0/coreml/style_transfer_udnie_coreml.pte'
: 'https://huggingface.co/software-mansion/react-native-executorch-style-transfer-udnie/resolve/v0.2.0/xnnpack/style_transfer_udnie_xnnpack.pte';

// Backward compatibility
export const LLAMA3_2_3B_URL = LLAMA3_2_3B;
export const LLAMA3_2_3B_QLORA_URL = LLAMA3_2_3B_QLORA;
export const LLAMA3_2_3B_SPINQUANT_URL = LLAMA3_2_3B_SPINQUANT;
export const LLAMA3_2_1B_URL = LLAMA3_2_1B;
export const LLAMA3_2_1B_QLORA_URL = LLAMA3_2_1B_QLORA;
export const LLAMA3_2_1B_SPINQUANT_URL = LLAMA3_2_1B_SPINQUANT;
12 changes: 6 additions & 6 deletions src/models/Classification.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ import { Classification } from '../native/RnExecutorchModules';
import { ETError, getError } from '../Error';

interface Props {
modulePath: string | number;
modelSource: string | number;
}

interface ClassificationModule {
Expand All @@ -15,18 +15,18 @@ interface ClassificationModule {
}

export const useClassification = ({
modulePath,
modelSource,
}: Props): ClassificationModule => {
const [error, setError] = useState<null | string>(null);
const [isModelReady, setIsModelReady] = useState(false);
const [isModelGenerating, setIsModelGenerating] = useState(false);

useEffect(() => {
const loadModel = async () => {
let path = modulePath;
let path = modelSource;

if (typeof modulePath === 'number') {
path = Image.resolveAssetSource(modulePath).uri;
if (typeof modelSource === 'number') {
path = Image.resolveAssetSource(modelSource).uri;
}

try {
Expand All @@ -40,7 +40,7 @@ export const useClassification = ({
};

loadModel();
}, [modulePath]);
}, [modelSource]);

const forward = async (input: string) => {
if (!isModelReady) {
Expand Down