Bläddra i källkod

feat: refactor the gpt3 api and support multi line completion

Lucas.Xu 2 år sedan
förälder
incheckning
fa0a334d6c

+ 1 - 1
frontend/app_flowy/packages/appflowy_editor/example/lib/pages/simple_editor.dart

@@ -4,7 +4,7 @@ import 'package:appflowy_editor/appflowy_editor.dart';
 import 'package:appflowy_editor_plugins/appflowy_editor_plugins.dart';
 import 'package:example/plugin/AI/continue_to_write.dart';
 import 'package:example/plugin/AI/auto_completion.dart';
-import 'package:example/plugin/AI/getgpt3completions.dart';
+import 'package:example/plugin/AI/gpt3.dart';
 import 'package:example/plugin/AI/smart_edit.dart';
 import 'package:flutter/material.dart';
 

+ 13 - 7
frontend/app_flowy/packages/appflowy_editor/example/lib/plugin/AI/auto_completion.dart

@@ -1,5 +1,5 @@
 import 'package:appflowy_editor/appflowy_editor.dart';
-import 'package:example/plugin/AI/getgpt3completions.dart';
+import 'package:example/plugin/AI/gpt3.dart';
 import 'package:example/plugin/AI/text_robot.dart';
 import 'package:flutter/material.dart';
 import 'package:flutter/services.dart';
@@ -37,12 +37,18 @@ SelectionMenuItem autoCompletionMenuItem = SelectionMenuItem(
                 Navigator.of(context).pop();
                 // fetch the result and insert it
                 final textRobot = TextRobot(editorState: editorState);
-                getGPT3Completion(apiKey, controller.text, '', (result) async {
-                  await textRobot.insertText(
-                    result,
-                    inputType: TextRobotInputType.character,
-                  );
-                });
+                const gpt3 = GPT3APIClient(apiKey: apiKey);
+                gpt3.getGPT3Completion(
+                  controller.text,
+                  '',
+                  onResult: (result) async {
+                    await textRobot.insertText(
+                      result,
+                      inputType: TextRobotInputType.character,
+                    );
+                  },
+                  onError: () async {},
+                );
               } else if (key.logicalKey == LogicalKeyboardKey.escape) {
                 Navigator.of(context).pop();
               }

+ 84 - 23
frontend/app_flowy/packages/appflowy_editor/example/lib/plugin/AI/continue_to_write.dart

@@ -1,5 +1,5 @@
 import 'package:appflowy_editor/appflowy_editor.dart';
-import 'package:example/plugin/AI/getgpt3completions.dart';
+import 'package:example/plugin/AI/gpt3.dart';
 import 'package:example/plugin/AI/text_robot.dart';
 import 'package:flutter/material.dart';
 
@@ -14,35 +14,96 @@ SelectionMenuItem continueToWriteMenuItem = SelectionMenuItem(
   ),
   keywords: ['continue to write'],
   handler: ((editorState, menuService, context) async {
-    // get the current text
+    // Two cases
+    // 1. if there is content in the text node where the cursor is located,
+    //  then we use the current text content as data.
+    // 2. if there is no content in the text node where the cursor is located,
+    // then we use the previous / next text node's content as data.
+
     final selection =
         editorState.service.selectionService.currentSelection.value;
-    final textNodes = editorState.service.selectionService.currentSelectedNodes;
-    if (selection == null || !selection.isCollapsed || textNodes.length != 1) {
+    if (selection == null || !selection.isCollapsed) {
+      return;
+    }
+
+    final textNodes = editorState.service.selectionService.currentSelectedNodes
+        .whereType<TextNode>();
+    if (textNodes.isEmpty) {
       return;
     }
-    final textNode = textNodes.first as TextNode;
-    final prompt = textNode.delta.slice(0, selection.startIndex).toPlainText();
-    final suffix = textNode.delta
-        .slice(
-          selection.endIndex,
-          textNode.toPlainText().length,
-        )
-        .toPlainText();
+
     final textRobot = TextRobot(editorState: editorState);
-    getGPT3Completion(
-      apiKey,
+    const gpt3 = GPT3APIClient(apiKey: apiKey);
+    final textNode = textNodes.first;
+
+    var prompt = '';
+    var suffix = '';
+
+    void continueToWriteInSingleLine() {
+      prompt = textNode.delta.slice(0, selection.startIndex).toPlainText();
+      suffix = textNode.delta
+          .slice(
+            selection.endIndex,
+            textNode.toPlainText().length,
+          )
+          .toPlainText();
+    }
+
+    void continueToWriteInMulitLines() {
+      final parent = textNode.parent;
+      if (parent != null) {
+        for (final node in parent.children) {
+          if (node is! TextNode || node.toPlainText().isEmpty) continue;
+          if (node.path < textNode.path) {
+            prompt += '${node.toPlainText()}\n';
+          } else if (node.path > textNode.path) {
+            suffix += '${node.toPlainText()}\n';
+          }
+        }
+      }
+    }
+
+    if (textNodes.first.toPlainText().isNotEmpty) {
+      continueToWriteInSingleLine();
+    } else {
+      continueToWriteInMulitLines();
+    }
+
+    if (prompt.isEmpty && suffix.isEmpty) {
+      return;
+    }
+
+    late final BuildContext diglogContext;
+
+    showDialog(
+      context: context,
+      builder: (context) {
+        diglogContext = context;
+        return AlertDialog(
+          content: Column(
+            mainAxisSize: MainAxisSize.min,
+            children: const [
+              CircularProgressIndicator(),
+              SizedBox(height: 10),
+              Text('Loading'),
+            ],
+          ),
+        );
+      },
+    );
+
+    gpt3.getGPT3Completion(
       prompt,
       suffix,
-      (result) async {
-        if (result == '\\n') {
-          await editorState.insertNewLineAtCurrentSelection();
-        } else {
-          await textRobot.insertText(
-            result,
-            inputType: TextRobotInputType.word,
-          );
-        }
+      onResult: (result) async {
+        Navigator.of(diglogContext).pop(true);
+        await textRobot.insertText(
+          result,
+          inputType: TextRobotInputType.word,
+        );
+      },
+      onError: () async {
+        Navigator.of(diglogContext).pop(true);
       },
     );
   }),

+ 0 - 111
frontend/app_flowy/packages/appflowy_editor/example/lib/plugin/AI/getgpt3completions.dart

@@ -1,111 +0,0 @@
-import 'package:http/http.dart' as http;
-import 'dart:async';
-import 'dart:convert';
-
-// Please fill in your own API key
-const apiKey = '';
-
-Future<void> getGPT3Completion(
-  String apiKey,
-  String prompt,
-  String suffix,
-  Future<void> Function(String)
-      onData, // callback function to handle streaming data
-  {
-  int maxTokens = 200,
-  double temperature = .3,
-  bool stream = true,
-}) async {
-  final data = {
-    'prompt': prompt,
-    'suffix': suffix,
-    'max_tokens': maxTokens,
-    'temperature': temperature,
-    'stream': stream, // set stream parameter to true
-  };
-
-  final headers = {
-    'Authorization': apiKey,
-    'Content-Type': 'application/json',
-  };
-  final request = http.Request(
-    'POST',
-    Uri.parse('https://api.openai.com/v1/engines/text-davinci-003/completions'),
-  );
-  request.body = json.encode(data);
-  request.headers.addAll(headers);
-
-  final httpResponse = await request.send();
-
-  if (httpResponse.statusCode == 200) {
-    await for (final chunk in httpResponse.stream) {
-      var result = utf8.decode(chunk).split('text": "');
-      var text = '';
-      if (result.length > 1) {
-        result = result[1].split('",');
-        if (result.isNotEmpty) {
-          text = result.first;
-        }
-      }
-
-      final processedText = text
-          .replaceAll('\\r', '\r')
-          .replaceAll('\\t', '\t')
-          .replaceAll('\\b', '\b')
-          .replaceAll('\\f', '\f')
-          .replaceAll('\\v', '\v')
-          .replaceAll('\\\'', '\'')
-          .replaceAll('"', '"')
-          .replaceAll('\\0', '0')
-          .replaceAll('\\1', '1')
-          .replaceAll('\\2', '2')
-          .replaceAll('\\3', '3')
-          .replaceAll('\\4', '4')
-          .replaceAll('\\5', '5')
-          .replaceAll('\\6', '6')
-          .replaceAll('\\7', '7')
-          .replaceAll('\\8', '8')
-          .replaceAll('\\9', '9');
-
-      await onData(processedText);
-    }
-  }
-}
-
-Future<void> getGPT3Edit(
-  String apiKey,
-  String input,
-  String instruction, {
-  required Future<void> Function(List<String> result) onResult,
-  required Future<void> Function() onError,
-  int n = 1,
-  double temperature = .3,
-}) async {
-  final data = {
-    'model': 'text-davinci-edit-001',
-    'input': input,
-    'instruction': instruction,
-    'temperature': temperature,
-    'n': n,
-  };
-
-  final headers = {
-    'Authorization': apiKey,
-    'Content-Type': 'application/json',
-  };
-
-  var response = await http.post(
-    Uri.parse('https://api.openai.com/v1/edits'),
-    headers: headers,
-    body: json.encode(data),
-  );
-  if (response.statusCode == 200) {
-    final result = json.decode(response.body);
-    final choices = result['choices'];
-    if (choices != null && choices is List) {
-      onResult(choices.map((e) => e['text'] as String).toList());
-    }
-  } else {
-    onError();
-  }
-}

+ 119 - 0
frontend/app_flowy/packages/appflowy_editor/example/lib/plugin/AI/gpt3.dart

@@ -0,0 +1,119 @@
+import 'package:http/http.dart' as http;
+import 'dart:async';
+import 'dart:convert';
+
+// Please fill in your own API key
+const apiKey = '';
+
+enum GPT3API {
+  completion,
+  edit,
+}
+
+extension on GPT3API {
+  Uri get uri {
+    switch (this) {
+      case GPT3API.completion:
+        return Uri.parse('https://api.openai.com/v1/completions');
+      case GPT3API.edit:
+        return Uri.parse('https://api.openai.com/v1/edits');
+    }
+  }
+}
+
+class GPT3APIClient {
+  const GPT3APIClient({
+    required this.apiKey,
+  });
+
+  final String apiKey;
+
+  /// Get completions from GPT-3
+  ///
+  /// [prompt] is the prompt text
+  /// [suffix] is the suffix text
+  /// [onResult] is the callback function to handle the result
+  /// [maxTokens] is the maximum number of tokens to generate
+  /// [temperature] is the temperature of the model
+  ///
+  /// See https://beta.openai.com/docs/api-reference/completions/create
+  Future<void> getGPT3Completion(
+    String prompt,
+    String suffix, {
+    required Future<void> Function(String result) onResult,
+    required Future<void> Function() onError,
+    int maxTokens = 200,
+    double temperature = .3,
+  }) async {
+    final data = {
+      'model': 'text-davinci-003',
+      'prompt': prompt,
+      'suffix': suffix,
+      'max_tokens': maxTokens,
+      'temperature': temperature,
+      'stream': false,
+    };
+
+    final headers = {
+      'Authorization': apiKey,
+      'Content-Type': 'application/json',
+    };
+
+    final response = await http.post(
+      GPT3API.completion.uri,
+      headers: headers,
+      body: json.encode(data),
+    );
+
+    if (response.statusCode == 200) {
+      final result = json.decode(response.body);
+      final choices = result['choices'];
+      if (choices != null && choices is List) {
+        for (final choice in choices) {
+          final text = choice['text'];
+          await onResult(text);
+        }
+      }
+    } else {
+      await onError();
+    }
+  }
+
+  Future<void> getGPT3Edit(
+    String apiKey,
+    String input,
+    String instruction, {
+    required Future<void> Function(List<String> result) onResult,
+    required Future<void> Function() onError,
+    int n = 1,
+    double temperature = .3,
+  }) async {
+    final data = {
+      'model': 'text-davinci-edit-001',
+      'input': input,
+      'instruction': instruction,
+      'temperature': temperature,
+      'n': n,
+    };
+
+    final headers = {
+      'Authorization': apiKey,
+      'Content-Type': 'application/json',
+    };
+
+    final response = await http.post(
+      Uri.parse('https://api.openai.com/v1/edits'),
+      headers: headers,
+      body: json.encode(data),
+    );
+    if (response.statusCode == 200) {
+      final result = json.decode(response.body);
+      final choices = result['choices'];
+      if (choices != null && choices is List) {
+        await onResult(choices.map((e) => e['text'] as String).toList());
+      }
+    } else {
+      await onError();
+    }
+  }
+}

+ 4 - 2
frontend/app_flowy/packages/appflowy_editor/example/lib/plugin/AI/smart_edit.dart

@@ -1,5 +1,5 @@
 import 'package:appflowy_editor/appflowy_editor.dart';
-import 'package:example/plugin/AI/getgpt3completions.dart';
+import 'package:example/plugin/AI/gpt3.dart';
 import 'package:flutter/material.dart';
 import 'package:flutter/services.dart';
 
@@ -52,6 +52,8 @@ class _SmartEditWidgetState extends State<SmartEditWidget> {
 
   var result = '';
 
+  final gpt3 = const GPT3APIClient(apiKey: apiKey);
+
   Iterable<TextNode> get currentSelectedTextNodes =>
       widget.editorState.service.selectionService.currentSelectedNodes
           .whereType<TextNode>();
@@ -180,7 +182,7 @@ class _SmartEditWidgetState extends State<SmartEditWidget> {
       },
     );
 
-    getGPT3Edit(
+    gpt3.getGPT3Edit(
       apiKey,
       text,
       inputEventController.text,

+ 8 - 5
frontend/app_flowy/packages/appflowy_editor/example/lib/plugin/AI/text_robot.dart

@@ -18,7 +18,7 @@ class TextRobot {
     String text, {
     TextRobotInputType inputType = TextRobotInputType.character,
   }) async {
-    final lines = text.split('\\n');
+    final lines = text.split('\n');
     for (final line in lines) {
       if (line.isEmpty) continue;
       switch (inputType) {
@@ -32,10 +32,13 @@ class TextRobot {
           }
           break;
         case TextRobotInputType.word:
-          await editorState.insertTextAtCurrentSelection(
-            line,
-          );
-          await Future.delayed(delay, () {});
+          final words = line.split(' ').map((e) => '$e ');
+          for (final word in words) {
+            await editorState.insertTextAtCurrentSelection(
+              word,
+            );
+            await Future.delayed(delay, () {});
+          }
           break;
       }
 

+ 1 - 2
frontend/app_flowy/packages/appflowy_editor/lib/src/core/transform/transaction.dart

@@ -171,10 +171,9 @@ extension TextTransaction on Transaction {
 
   void splitText(TextNode textNode, int offset) {
     final delta = textNode.delta;
-    final first = delta.slice(0, offset);
     final second = delta.slice(offset, delta.length);
     final path = textNode.path.next;
-    updateText(textNode, first);
+    deleteText(textNode, offset, delta.length);
     insertNode(
       path,
       TextNode(