소스 검색

feat: improve the smart edit user-experience

Lucas.Xu 2 년 전
부모
커밋
653f88033e

+ 2 - 1
frontend/appflowy_flutter/assets/translations/en.json

@@ -138,7 +138,8 @@
     "keep": "Keep",
     "tryAgain": "Try again",
     "discard": "Discard",
-    "replace": "Replace"
+    "replace": "Replace",
+    "insertBelow": "Insert Below"
   },
   "label": {
     "welcome": "Welcome!",

+ 17 - 8
frontend/appflowy_flutter/lib/plugins/document/presentation/plugins/openai/service/openai_client.dart

@@ -1,7 +1,6 @@
 import 'dart:convert';
 
 import 'package:appflowy/plugins/document/presentation/plugins/openai/service/text_edit.dart';
-import 'package:appflowy_editor/appflowy_editor.dart';
 
 import 'text_completion.dart';
 import 'package:dartz/dartz.dart';
@@ -125,6 +124,7 @@ class HttpOpenAIRepository implements OpenAIRepository {
     String? suffix,
     int maxTokens = 2048,
     double temperature = 0.3,
+    bool useAction = false,
   }) async {
     final parameters = {
       'model': 'text-davinci-003',
@@ -151,14 +151,22 @@ class HttpOpenAIRepository implements OpenAIRepository {
           .transform(const Utf8Decoder())
           .transform(const LineSplitter())) {
         syntax += 1;
-        if (syntax == 3) {
-          await onStart();
-          continue;
-        } else if (syntax < 3) {
-          continue;
+        if (!useAction) {
+          if (syntax == 3) {
+            await onStart();
+            continue;
+          } else if (syntax < 3) {
+            continue;
+          }
+        } else {
+          if (syntax == 2) {
+            await onStart();
+            continue;
+          } else if (syntax < 2) {
+            continue;
+          }
         }
         final data = chunk.trim().split('data: ');
-        Log.editor.info(data.toString());
         if (data.length > 1) {
           if (data[1] != '[DONE]') {
             final response = TextCompletionResponse.fromJson(
@@ -173,7 +181,7 @@ class HttpOpenAIRepository implements OpenAIRepository {
               previousSyntax = response.choices.first.text;
             }
           } else {
-            onEnd();
+            await onEnd();
           }
         }
       }
@@ -183,6 +191,7 @@ class HttpOpenAIRepository implements OpenAIRepository {
         OpenAIError.fromJson(json.decode(body)['error']),
       );
     }
+    return;
   }
 
   @override

+ 20 - 1
frontend/appflowy_flutter/lib/plugins/document/presentation/plugins/openai/widgets/smart_edit_action.dart

@@ -10,11 +10,30 @@ enum SmartEditAction {
   String get toInstruction {
     switch (this) {
       case SmartEditAction.summarize:
-        return 'Make this shorter and more concise:';
+        return 'Tl;dr';
       case SmartEditAction.fixSpelling:
         return 'Correct this to standard English:';
     }
   }
+
+  String prompt(String input) {
+    switch (this) {
+      case SmartEditAction.summarize:
+        return '$input\n\nTl;dr';
+      case SmartEditAction.fixSpelling:
+        return 'Correct this to standard English:\n\n$input';
+    }
+  }
+
+  static SmartEditAction from(int index) {
+    switch (index) {
+      case 0:
+        return SmartEditAction.summarize;
+      case 1:
+        return SmartEditAction.fixSpelling;
+    }
+    return SmartEditAction.fixSpelling;
+  }
 }
 
 class SmartEditActionWrapper extends ActionCell {

+ 102 - 61
frontend/appflowy_flutter/lib/plugins/document/presentation/plugins/openai/widgets/smart_edit_node_widget.dart

@@ -1,6 +1,4 @@
-import 'package:appflowy/plugins/document/presentation/plugins/openai/service/error.dart';
 import 'package:appflowy/plugins/document/presentation/plugins/openai/service/openai_client.dart';
-import 'package:appflowy/plugins/document/presentation/plugins/openai/service/text_edit.dart';
 import 'package:appflowy/plugins/document/presentation/plugins/openai/util/learn_more_action.dart';
 import 'package:appflowy/plugins/document/presentation/plugins/openai/widgets/smart_edit_action.dart';
 import 'package:appflowy/user/application/user_service.dart';
@@ -12,8 +10,6 @@ import 'package:flutter/material.dart';
 import 'package:appflowy/generated/locale_keys.g.dart';
 import 'package:easy_localization/easy_localization.dart';
 import 'package:http/http.dart' as http;
-import 'package:dartz/dartz.dart' as dartz;
-import 'package:appflowy/util/either_extension.dart';
 
 const String kSmartEditType = 'smart_edit_input';
 const String kSmartEditInstructionType = 'smart_edit_instruction';
@@ -22,9 +18,9 @@ const String kSmartEditInputType = 'smart_edit_input';
 class SmartEditInputBuilder extends NodeWidgetBuilder<Node> {
   @override
   NodeValidator<Node> get nodeValidator => (node) {
-        return SmartEditAction.values.map((e) => e.toInstruction).contains(
-                  node.attributes[kSmartEditInstructionType],
-                ) &&
+        return SmartEditAction.values
+                .map((e) => e.index)
+                .contains(node.attributes[kSmartEditInstructionType]) &&
             node.attributes[kSmartEditInputType] is String;
       };
 
@@ -53,13 +49,14 @@ class _SmartEditInput extends StatefulWidget {
 }
 
 class _SmartEditInputState extends State<_SmartEditInput> {
-  String get instruction => widget.node.attributes[kSmartEditInstructionType];
+  SmartEditAction get action =>
+      SmartEditAction.from(widget.node.attributes[kSmartEditInstructionType]);
   String get input => widget.node.attributes[kSmartEditInputType];
 
   final focusNode = FocusNode();
   final client = http.Client();
-  dartz.Either<OpenAIError, TextEditResponse>? result;
   bool loading = true;
+  String result = '';
 
   @override
   void initState() {
@@ -72,12 +69,7 @@ class _SmartEditInputState extends State<_SmartEditInput> {
         widget.editorState.service.keyboardService?.enable();
       }
     });
-    _requestEdits().then(
-      (value) => setState(() {
-        result = value;
-        loading = false;
-      }),
-    );
+    _requestCompletions();
   }
 
   @override
@@ -141,25 +133,14 @@ class _SmartEditInputState extends State<_SmartEditInput> {
         child: const CircularProgressIndicator(),
       ),
     );
-    if (result == null) {
+    if (result.isEmpty) {
       return loading;
     }
-    return result!.fold((error) {
-      return Flexible(
-        child: Text(
-          error.message,
-          style: Theme.of(context).textTheme.bodyMedium?.copyWith(
-                color: Colors.red,
-              ),
-        ),
-      );
-    }, (response) {
-      return Flexible(
-        child: Text(
-          response.choices.map((e) => e.text).join('\n'),
-        ),
-      );
-    });
+    return Flexible(
+      child: Text(
+        result,
+      ),
+    );
   }
 
   Widget _buildInputFooterWidget(BuildContext context) {
@@ -174,8 +155,23 @@ class _SmartEditInputState extends State<_SmartEditInput> {
               ),
             ],
           ),
-          onPressed: () {
-            _onReplace();
+          onPressed: () async {
+            await _onReplace();
+            _onExit();
+          },
+        ),
+        const Space(10, 0),
+        FlowyRichTextButton(
+          TextSpan(
+            children: [
+              TextSpan(
+                text: LocaleKeys.button_insertBelow.tr(),
+                style: Theme.of(context).textTheme.bodyMedium,
+              ),
+            ],
+          ),
+          onPressed: () async {
+            await _onInsertBelow();
             _onExit();
           },
         ),
@@ -201,12 +197,11 @@ class _SmartEditInputState extends State<_SmartEditInput> {
     final selectedNodes = widget
         .editorState.service.selectionService.currentSelectedNodes.normalized
         .whereType<TextNode>();
-    if (selection == null || result == null || result!.isLeft()) {
+    if (selection == null || result.isEmpty) {
       return;
     }
 
-    final texts = result!.asRight().choices.first.text.split('\n')
-      ..removeWhere((element) => element.isEmpty);
+    final texts = result.split('\n')..removeWhere((element) => element.isEmpty);
     final transaction = widget.editorState.transaction;
     transaction.replaceTexts(
       selectedNodes.toList(growable: false),
@@ -216,6 +211,25 @@ class _SmartEditInputState extends State<_SmartEditInput> {
     return widget.editorState.apply(transaction);
   }
 
+  Future<void> _onInsertBelow() async {
+    final selection = widget.editorState.service.selectionService
+        .currentSelection.value?.normalized;
+    if (selection == null || result.isEmpty) {
+      return;
+    }
+    final texts = result.split('\n')..removeWhere((element) => element.isEmpty);
+    final transaction = widget.editorState.transaction;
+    transaction.insertNodes(
+      selection.normalized.end.path.next,
+      texts.map(
+        (e) => TextNode(
+          delta: Delta()..insert(e),
+        ),
+      ),
+    );
+    return widget.editorState.apply(transaction);
+  }
+
   Future<void> _onExit() async {
     final transaction = widget.editorState.transaction;
     transaction.deleteNode(widget.node);
@@ -228,35 +242,62 @@ class _SmartEditInputState extends State<_SmartEditInput> {
     );
   }
 
-  Future<dartz.Either<OpenAIError, TextEditResponse>> _requestEdits() async {
+  Future<void> _requestCompletions() async {
     final result = await UserBackendService.getCurrentUserProfile();
-    return result.fold((userProfile) async {
+    return result.fold((l) async {
       final openAIRepository = HttpOpenAIRepository(
         client: client,
-        apiKey: userProfile.openaiKey,
-      );
-      final edits = await openAIRepository.getEdits(
-        input: input,
-        instruction: instruction,
-        n: 1,
+        apiKey: l.openaiKey,
       );
-      return edits.fold((error) async {
-        return dartz.Left(
-          OpenAIError(
-            message:
-                LocaleKeys.document_plugins_smartEditCouldNotFetchResult.tr(),
-          ),
+      var lines = input.split('\n\n');
+      if (action == SmartEditAction.summarize) {
+        lines = [lines.join('\n')];
+      }
+      for (var i = 0; i < lines.length; i++) {
+        final element = lines[i];
+        await openAIRepository.getStreamedCompletions(
+          useAction: true,
+          prompt: action.prompt(element),
+          onStart: () async {
+            setState(() {
+              loading = false;
+            });
+          },
+          onProcess: (response) async {
+            setState(() {
+              this.result += response.choices.first.text;
+            });
+          },
+          onEnd: () async {
+            setState(() {
+              if (i != lines.length - 1) {
+                this.result += '\n';
+              }
+            });
+          },
+          onError: (error) async {
+            await _showError(error.message);
+            await _onExit();
+          },
         );
-      }, (textEdit) async {
-        return dartz.Right(textEdit);
-      });
-    }, (error) async {
-      // error
-      return dartz.Left(
-        OpenAIError(
-          message: LocaleKeys.document_plugins_smartEditCouldNotFetchKey.tr(),
-        ),
-      );
+      }
+    }, (r) async {
+      await _showError(r.msg);
+      await _onExit();
     });
   }
+
+  Future<void> _showError(String message) async {
+    ScaffoldMessenger.of(context).showSnackBar(
+      SnackBar(
+        action: SnackBarAction(
+          label: LocaleKeys.button_Cancel.tr(),
+          onPressed: () {
+            ScaffoldMessenger.of(context).hideCurrentSnackBar();
+          },
+        ),
+        content: FlowyText(message),
+      ),
+    );
+  }
 }

+ 6 - 4
frontend/appflowy_flutter/lib/plugins/document/presentation/plugins/openai/widgets/smart_edit_toolbar_item.dart

@@ -16,8 +16,7 @@ ToolbarItem smartEditItem = ToolbarItem(
   validator: (editorState) {
     // All selected nodes must be text.
     final nodes = editorState.service.selectionService.currentSelectedNodes;
-    return nodes.whereType<TextNode>().length == nodes.length &&
-        nodes.length == 1;
+    return nodes.whereType<TextNode>().length == nodes.length;
   },
   itemBuilder: (context, editorState) {
     return _SmartEditWidget(
@@ -102,14 +101,17 @@ class _SmartEditWidgetState extends State<_SmartEditWidget> {
       textNodes.normalized,
       selection.normalized,
     );
+    while (input.last.isEmpty) {
+      input.removeLast();
+    }
     final transaction = widget.editorState.transaction;
     transaction.insertNode(
       selection.normalized.end.path.next,
       Node(
         type: kSmartEditType,
         attributes: {
-          kSmartEditInstructionType: actionWrapper.inner.toInstruction,
-          kSmartEditInputType: input,
+          kSmartEditInstructionType: actionWrapper.inner.index,
+          kSmartEditInputType: input.join('\n\n'),
         },
       ),
     );

+ 0 - 2
frontend/appflowy_flutter/lib/workspace/presentation/settings/widgets/settings_user_view.dart

@@ -1,6 +1,5 @@
 import 'package:appflowy/startup/startup.dart';
 import 'package:appflowy/util/debounce.dart';
-import 'package:appflowy_backend/log.dart';
 import 'package:flowy_infra/size.dart';
 import 'package:flowy_infra_ui/style_widget/text.dart';
 import 'package:flutter/material.dart';
@@ -133,7 +132,6 @@ class _OpenaiKeyInputState extends State<_OpenaiKeyInput> {
       ),
       onChanged: (value) {
         debounce.call(() {
-          Log.debug('SettingsUserViewBloc');
           context
               .read<SettingsUserViewBloc>()
               .add(SettingsUserEvent.updateUserOpenAIKey(value));

+ 2 - 2
frontend/appflowy_flutter/packages/appflowy_editor/lib/src/commands/command_extension.dart

@@ -52,7 +52,7 @@ extension CommandExtension on EditorState {
     throw Exception('path and textNode cannot be null at the same time');
   }
 
-  String getTextInSelection(
+  List<String> getTextInSelection(
     List<TextNode> textNodes,
     Selection selection,
   ) {
@@ -77,6 +77,6 @@ extension CommandExtension on EditorState {
         }
       }
     }
-    return res.join('\n');
+    return res;
   }
 }

+ 4 - 14
frontend/appflowy_flutter/packages/appflowy_editor/lib/src/core/transform/transaction.dart

@@ -337,7 +337,6 @@ extension TextTransaction on Transaction {
     }
 
     if (textNodes.length > texts.length) {
-      final length = textNodes.length;
       for (var i = 0; i < textNodes.length; i++) {
         final textNode = textNodes[i];
         if (i == 0) {
@@ -347,24 +346,15 @@ extension TextTransaction on Transaction {
             textNode.toPlainText().length,
             texts.first,
           );
-        } else if (i == length - 1) {
+        } else if (i < texts.length - 1) {
           replaceText(
             textNode,
             0,
-            selection.endIndex,
-            texts.last,
+            textNode.toPlainText().length,
+            texts[i],
           );
         } else {
-          if (i < texts.length - 1) {
-            replaceText(
-              textNode,
-              0,
-              textNode.toPlainText().length,
-              texts[i],
-            );
-          } else {
-            deleteNode(textNode);
-          }
+          deleteNode(textNode);
         }
       }
       afterSelection = null;